warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (192)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/native/tile_radix_sort.h (new file)
@@ -0,0 +1,1112 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ #include "tile.h"
21
+
22
+ #if defined(__clang__)
23
+ // disable warnings related to C++17 extensions on CPU JIT builds
24
+ #pragma clang diagnostic push
25
+ #pragma clang diagnostic ignored "-Wc++17-extensions"
26
+ #endif
27
+
28
+ namespace wp
29
+ {
30
+
31
+
32
+ // After this threshold, using segmented_sort from cub is faster
33
+ // The threshold must be a power of 2
34
+ // The radix sort in this file is consistently slower than the bitonic sort
35
+ #define BITONIC_SORT_THRESHOLD 2048
36
+
37
+ struct UintKeyToUint
38
+ {
39
+ inline CUDA_CALLABLE uint32_t convert(uint32 value)
40
+ {
41
+ return value;
42
+ }
43
+
44
+ inline CUDA_CALLABLE uint32_t max_possible_key_value()
45
+ {
46
+ return 0xFFFFFFFF;
47
+ }
48
+ };
49
+
50
+ struct IntKeyToUint
51
+ {
52
+ inline CUDA_CALLABLE uint32_t convert(int value)
53
+ {
54
+ // Flip the sign bit: ensures negative numbers come before positive numbers
55
+ return static_cast<uint32_t>(value) ^ 0x80000000;
56
+ }
57
+
58
+ inline CUDA_CALLABLE int max_possible_key_value()
59
+ {
60
+ return 2147483647;
61
+ }
62
+ };
63
+
64
+ struct FloatKeyToUint
65
+ {
66
+ //http://stereopsis.com/radix.html
67
+ inline CUDA_CALLABLE uint32_t convert(float value)
68
+ {
69
+ unsigned int i = reinterpret_cast<unsigned int&>(value);
70
+ unsigned int mask = (unsigned int)(-(int)(i >> 31)) | 0x80000000;
71
+ return i ^ mask;
72
+ }
73
+
74
+ inline CUDA_CALLABLE float max_possible_key_value()
75
+ {
76
+ return FLT_MAX;
77
+ }
78
+ };
79
+
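The key converters above reduce signed and floating-point comparisons to unsigned ones: IntKeyToUint flips the sign bit so negative integers order below positive ones, and FloatKeyToUint applies the stereopsis.com trick (flip every bit of a negative float, only the sign bit of a non-negative one) so that unsigned comparison of the converted bit patterns matches float ordering. A minimal standalone host-side sketch of the same mappings; the helper names int_to_sortable_uint and float_to_sortable_uint are illustrative and not part of this header:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Same mapping as IntKeyToUint::convert: flip the sign bit.
    static uint32_t int_to_sortable_uint(int32_t value)
    {
        return static_cast<uint32_t>(value) ^ 0x80000000u;
    }

    // Same mapping as FloatKeyToUint::convert, using memcpy instead of reinterpret_cast.
    static uint32_t float_to_sortable_uint(float value)
    {
        uint32_t i;
        std::memcpy(&i, &value, sizeof(i));
        uint32_t mask = static_cast<uint32_t>(-static_cast<int32_t>(i >> 31)) | 0x80000000u;
        return i ^ mask;
    }

    int main()
    {
        // Unsigned order of the converted keys matches the natural order of the inputs.
        int32_t ints[] = { -2000000000, -1, 0, 1, 2000000000 };
        for (int i = 1; i < 5; ++i)
            assert(int_to_sortable_uint(ints[i - 1]) < int_to_sortable_uint(ints[i]));

        float floats[] = { -1e30f, -1.5f, -0.0f, 0.0f, 0.25f, 3.0f, 1e30f };
        for (int i = 1; i < 7; ++i)
            assert(float_to_sortable_uint(floats[i - 1]) <= float_to_sortable_uint(floats[i]));
        return 0;
    }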
80
+
81
+ constexpr inline CUDA_CALLABLE bool is_power_of_two(int x)
82
+ {
83
+ return (x & (x - 1)) == 0;
84
+ }
85
+
86
+ constexpr inline CUDA_CALLABLE int next_higher_pow2(int input)
87
+ {
88
+ if (input <= 0) return 1; // Smallest power of 2 is 1
89
+
90
+ input--; // Decrement to handle already a power of 2 cases
91
+ input |= input >> 1;
92
+ input |= input >> 2;
93
+ input |= input >> 4;
94
+ input |= input >> 8;
95
+ input |= input >> 16;
96
+ input++; // Next power of 2
97
+
98
+ return input;
99
+ }
100
+
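For reference, next_higher_pow2 rounds any positive input up to the nearest power of two and leaves exact powers of two unchanged, using the usual bit-smearing idiom. A host-only copy for illustration (the name next_higher_pow2_host is mine, not the header's) with a few compile-time checks:

    // Host-only copy of next_higher_pow2 for illustration (no CUDA qualifiers).
    constexpr int next_higher_pow2_host(int input)
    {
        if (input <= 0) return 1;
        input--;                 // handle inputs that are already powers of two
        input |= input >> 1;
        input |= input >> 2;
        input |= input >> 4;
        input |= input >> 8;
        input |= input >> 16;    // all bits below the highest set bit are now 1
        return input + 1;        // next power of two
    }

    static_assert(next_higher_pow2_host(1) == 1, "");
    static_assert(next_higher_pow2_host(3) == 4, "");
    static_assert(next_higher_pow2_host(1000) == 1024, "");
    static_assert(next_higher_pow2_host(2048) == 2048, "");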
101
+
102
+ #if defined(__CUDA_ARCH__)
103
+
104
+
105
+ // Bitonic sort fast pass for small arrays
106
+
107
+ template<typename T>
108
+ inline CUDA_CALLABLE T shfl_xor(unsigned int thread_id, T* sh_mem, unsigned int lane_mask)
109
+ {
110
+ unsigned int source_lane = thread_id ^ lane_mask;
111
+ return sh_mem[source_lane];
112
+ }
113
+
114
+ template<typename K, typename V, int num_loops>
115
+ inline CUDA_CALLABLE void bitonic_sort_single_stage_full_thread_block(int k, unsigned int thread_id, unsigned int stride, K* key_sh_mem, V* val_sh_mem, int length, K max_key_value,
116
+ K* key_register, V* val_register)
117
+ {
118
+ __syncthreads();
119
+ #pragma unroll
120
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
121
+ {
122
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
123
+
124
+ key_register[loop_id] = thread_id2 < length ? key_sh_mem[thread_id2] : max_key_value;
125
+ val_register[loop_id] = thread_id2 < length ? val_sh_mem[thread_id2] : static_cast<V>(0);
126
+ }
127
+
128
+ __syncthreads();
129
+
130
+ K s_key[num_loops];
131
+ V s_val[num_loops];
132
+ bool swap[num_loops];
133
+ #pragma unroll
134
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
135
+ {
136
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
137
+
138
+ if(thread_id2 < length)
139
+ {
140
+ s_key[loop_id] = shfl_xor(thread_id2, key_sh_mem, stride);
141
+ s_val[loop_id] = shfl_xor(thread_id2, val_sh_mem, stride);
142
+ swap[loop_id] = (((thread_id2 & stride) != 0 ? key_register[loop_id] > s_key[loop_id] : key_register[loop_id] < s_key[loop_id])) ^ ((thread_id2 & k) == 0);
143
+ }
144
+ }
145
+
146
+ __syncthreads();
147
+
148
+ #pragma unroll
149
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
150
+ {
151
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
152
+ if (thread_id2 < length)
153
+ {
154
+ key_sh_mem[thread_id2] = swap[loop_id] ? s_key[loop_id] : key_register[loop_id];
155
+ val_sh_mem[thread_id2] = swap[loop_id] ? s_val[loop_id] : val_register[loop_id];
156
+ }
157
+ }
158
+ __syncthreads();
159
+ }
160
+
161
+ //stride can be 1, 2, 4, 8, 16
162
+ template<typename K, typename V>
163
+ inline CUDA_CALLABLE void bitonic_sort_single_stage_full_warp(int k, unsigned int thread_id, int stride, K& key, V& val)
164
+ {
165
+ auto s_key = __shfl_xor_sync(0xFFFFFFFFu, key, stride);
166
+ auto s_val = __shfl_xor_sync(0xFFFFFFFFu, val, stride);
167
+ auto swap = (((thread_id & stride) != 0 ? key > s_key : key < s_key)) ^ ((thread_id & k) == 0);
168
+ key = swap ? s_key : key;
169
+ val = swap ? s_val : val;
170
+ }
171
+
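The swap predicate above is the heart of the bitonic network: lane i exchanges with lane i XOR stride, the comparison direction inside a pair depends on (thread_id & stride), and the XOR with (thread_id & k) == 0 flips that direction so blocks of size k alternate ascending/descending until the final merge (k equal to the full length) leaves everything ascending. A host-side sketch of the same rule applied to 8 "lanes", for illustration only and not part of this header:

    #include <cassert>

    int main()
    {
        int key[8] = { 5, 1, 7, 3, 2, 8, 6, 4 };

        for (int k = 2; k <= 8; k <<= 1)                 // bitonic merge sizes
        {
            for (int stride = k / 2; stride > 0; stride >>= 1)
            {
                int next[8];
                for (int lane = 0; lane < 8; ++lane)
                {
                    int partner = key[lane ^ stride];    // what shfl_xor would return
                    bool swap = (((lane & stride) != 0) ? key[lane] > partner
                                                        : key[lane] < partner)
                                ^ ((lane & k) == 0);
                    next[lane] = swap ? partner : key[lane];
                }
                for (int lane = 0; lane < 8; ++lane)
                    key[lane] = next[lane];
            }
        }

        for (int i = 1; i < 8; ++i)
            assert(key[i - 1] <= key[i]);                // final merge leaves keys ascending
        return 0;
    }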
172
+
173
+ //Sorts 32 elements according to keys
174
+ template<typename K, typename V>
175
+ inline CUDA_CALLABLE void bitonic_sort_single_warp(unsigned int thread_id, K& key, V& val)
176
+ {
177
+ #pragma unroll
178
+ for (int k = 2; k <= 32; k <<= 1)
179
+ {
180
+ #pragma unroll
181
+ for (int stride = k / 2; stride > 0; stride >>= 1)
182
+ {
183
+ bitonic_sort_single_stage_full_warp(k, thread_id, stride, key, val);
184
+ }
185
+ }
186
+ }
187
+
188
+ template<typename K, typename V, typename KeyToUint>
189
+ inline CUDA_CALLABLE void bitonic_sort_single_warp(int thread_id,
190
+ K* keys_input,
191
+ V* values_input,
192
+ int num_elements_to_sort)
193
+ {
194
+ KeyToUint key_converter;
195
+
196
+ __syncwarp();
197
+
198
+ K key = thread_id < num_elements_to_sort ? keys_input[thread_id] : key_converter.max_possible_key_value();
199
+ V value;
200
+ if(thread_id < num_elements_to_sort)
201
+ value = values_input[thread_id];
202
+
203
+ __syncwarp();
204
+ bitonic_sort_single_warp(thread_id, key, value);
205
+ __syncwarp();
206
+
207
+ if(thread_id < num_elements_to_sort)
208
+ {
209
+ keys_input[thread_id] = key;
210
+ values_input[thread_id] = value;
211
+ }
212
+ __syncwarp();
213
+ }
214
+
215
+
216
+ //Sorts according to keys
217
+ template<int max_num_elements, typename K, typename V>
218
+ inline CUDA_CALLABLE void bitonic_sort_pow2_length(unsigned int thread_id, K* key_sh_mem, V* val_sh_mem, int length, K key_max_possible_value)
219
+ {
220
+ constexpr int num_loops = (max_num_elements + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
221
+ K key[num_loops];
222
+ V val[num_loops];
223
+
224
+ #pragma unroll
225
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
226
+ {
227
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
228
+ key[loop_id] = thread_id2 < length ? key_sh_mem[thread_id2] : key_max_possible_value;
229
+ if (thread_id2 < length)
230
+ val[loop_id] = val_sh_mem[thread_id2];
231
+ }
232
+
233
+ __syncthreads();
234
+ bool full_block_sort_active = false;
235
+
236
+ for (int k = 2; k <= length; k <<= 1)
237
+ {
238
+ for (int stride = k / 2; stride > 0; stride >>= 1)
239
+ {
240
+ if (stride <= 16) //no inter-warp communication needed up to stride 16
241
+ {
242
+ if(full_block_sort_active)
243
+ {
244
+ __syncthreads();
245
+ #pragma unroll
246
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
247
+ {
248
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
249
+
250
+ //Switch from shared mem to registers
251
+ if (thread_id2 < length)
252
+ {
253
+ key[loop_id] = key_sh_mem[thread_id2];
254
+ val[loop_id] = val_sh_mem[thread_id2];
255
+ }
256
+ }
257
+ full_block_sort_active = false;
258
+ __syncthreads();
259
+ }
260
+
261
+ #pragma unroll
262
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
263
+ {
264
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
265
+ bitonic_sort_single_stage_full_warp(k, thread_id2, stride, key[loop_id], val[loop_id]);
266
+ }
267
+ }
268
+ else
269
+ {
270
+ if (!full_block_sort_active)
271
+ {
272
+ __syncthreads();
273
+ #pragma unroll
274
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
275
+ {
276
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
277
+
278
+ //Switch from registers to shared mem
279
+ if (thread_id2 < length)
280
+ {
281
+ key_sh_mem[thread_id2] = key[loop_id];
282
+ val_sh_mem[thread_id2] = val[loop_id];
283
+ }
284
+ }
285
+ full_block_sort_active = true;
286
+ __syncthreads();
287
+ }
288
+
289
+ bitonic_sort_single_stage_full_thread_block<K, V, num_loops>(k, thread_id, (unsigned int)stride, key_sh_mem, val_sh_mem, length, key_max_possible_value, key, val);
290
+ }
291
+ }
292
+ }
293
+
294
+ if (!full_block_sort_active)
295
+ {
296
+ #pragma unroll
297
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
298
+ {
299
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
300
+ //Switch from registers to shared mem
301
+ if (thread_id2 < length)
302
+ {
303
+ key_sh_mem[thread_id2] = key[loop_id];
304
+ val_sh_mem[thread_id2] = val[loop_id];
305
+ }
306
+ }
307
+ full_block_sort_active = true;
308
+ __syncthreads();
309
+ }
310
+ }
311
+
312
+ //Allocates shared memory to buffer the arrays that need to be sorted
313
+ template <int max_num_elements, typename K, typename V, typename KeyToUint>
314
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
315
+ int thread_id,
316
+ K* keys_input,
317
+ V* values_input,
318
+ int num_elements_to_sort)
319
+ {
320
+ if constexpr(max_num_elements < 32)
321
+ {
322
+ //Fast track - single warp sort
323
+ if (thread_id < 32)
324
+ bitonic_sort_single_warp<K, V, KeyToUint>(thread_id, keys_input, values_input, num_elements_to_sort);
325
+ __syncthreads();
326
+ }
327
+ else
328
+ {
329
+ KeyToUint key_converter;
330
+ const K key_max_possible_value = key_converter.max_possible_key_value();
331
+
332
+ constexpr int shared_mem_count = next_higher_pow2(max_num_elements);
333
+
334
+ __shared__ K keys_shared_mem[shared_mem_count]; //TODO: This shared memory can be avoided if keys_input is already shared memory
335
+ __shared__ V values_shared_mem[shared_mem_count]; //TODO: This shared memory can be avoided if values_input is already shared memory
336
+
337
+ for(int i = thread_id; i < shared_mem_count; i += WP_TILE_BLOCK_DIM)
338
+ {
339
+ if (i < num_elements_to_sort)
340
+ {
341
+ keys_shared_mem[i] = keys_input[i];
342
+ values_shared_mem[i] = values_input[i];
343
+ }
344
+ else
345
+ {
346
+ // Note that these padding values may end up in the output if enough NaN or Inf values are present in keys_input
347
+ keys_shared_mem[i] = key_max_possible_value;
348
+ values_shared_mem[i] = static_cast<V>(0);
349
+ }
350
+ }
351
+ __syncthreads();
352
+
353
+ bitonic_sort_pow2_length<shared_mem_count, K, V>((unsigned int)thread_id, keys_shared_mem, values_shared_mem, shared_mem_count, key_max_possible_value);
354
+
355
+ __syncthreads();
356
+
357
+ for (int i = thread_id; i < num_elements_to_sort; i += WP_TILE_BLOCK_DIM)
358
+ {
359
+ keys_input[i] = keys_shared_mem[i];
360
+ values_input[i] = values_shared_mem[i];
361
+ }
362
+ __syncthreads();
363
+ }
364
+ }
365
+
366
+
367
+ // Specialization for int keys
368
+ template <int max_num_elements, typename V>
369
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
370
+ int thread_id,
371
+ int* keys_input,
372
+ V* values_input,
373
+ int num_elements_to_sort)
374
+ {
375
+ bitonic_sort_thread_block_shared_mem<max_num_elements, int, V, IntKeyToUint>(
376
+ thread_id, keys_input, values_input, num_elements_to_sort);
377
+ }
378
+
379
+ // Specialization for unsigned int keys
380
+ template <int max_num_elements, typename V>
381
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
382
+ int thread_id,
383
+ unsigned int* keys_input,
384
+ V* values_input,
385
+ int num_elements_to_sort)
386
+ {
387
+ bitonic_sort_thread_block_shared_mem<max_num_elements, unsigned int, V, UintKeyToUint>(
388
+ thread_id, keys_input, values_input, num_elements_to_sort);
389
+ }
390
+
391
+ // Specialization for float keys
392
+ template <int max_num_elements, typename V>
393
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
394
+ int thread_id,
395
+ float* keys_input,
396
+ V* values_input,
397
+ int num_elements_to_sort)
398
+ {
399
+ bitonic_sort_thread_block_shared_mem<max_num_elements, float, V, FloatKeyToUint>(
400
+ thread_id, keys_input, values_input, num_elements_to_sort);
401
+ }
402
+
403
+
404
+
405
+ // Ideally keys_input and values_input point into fast memory (shared memory)
406
+ template <int max_num_elements, typename K, typename V, typename KeyToUint>
407
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
408
+ int thread_id,
409
+ K* keys_input,
410
+ V* values_input,
411
+ int num_elements_to_sort)
412
+ {
413
+ if constexpr(max_num_elements < 32)
414
+ {
415
+ //Fast track - single warp sort
416
+ if (thread_id < 32)
417
+ bitonic_sort_single_warp<K, V, KeyToUint>(thread_id, keys_input, values_input, num_elements_to_sort);
418
+ __syncthreads();
419
+ }
420
+ else
421
+ {
422
+ assert(num_elements_to_sort <= max_num_elements);
423
+
424
+ KeyToUint key_converter;
425
+ const K key_max_possible_value = key_converter.max_possible_key_value();
426
+
427
+ bitonic_sort_pow2_length<max_num_elements, K, V>((unsigned int)thread_id, keys_input, values_input, num_elements_to_sort, key_max_possible_value);
428
+ }
429
+ }
430
+
431
+ // Specialization for int keys
432
+ template <int max_num_elements, typename V>
433
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
434
+ int thread_id,
435
+ int* keys_input,
436
+ V* values_input,
437
+ int num_elements_to_sort)
438
+ {
439
+ bitonic_sort_thread_block_direct<max_num_elements, int, V, IntKeyToUint>(
440
+ thread_id, keys_input, values_input, num_elements_to_sort);
441
+ }
442
+
443
+ // Specialization for unsigned int keys
444
+ template <int max_num_elements, typename V>
445
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
446
+ int thread_id,
447
+ unsigned int* keys_input,
448
+ V* values_input,
449
+ int num_elements_to_sort)
450
+ {
451
+ bitonic_sort_thread_block_direct<max_num_elements, unsigned int, V, UintKeyToUint>(
452
+ thread_id, keys_input, values_input, num_elements_to_sort);
453
+ }
454
+
455
+ // Specialization for float keys
456
+ template <int max_num_elements, typename V>
457
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
458
+ int thread_id,
459
+ float* keys_input,
460
+ V* values_input,
461
+ int num_elements_to_sort)
462
+ {
463
+ bitonic_sort_thread_block_direct<max_num_elements, float, V, FloatKeyToUint>(
464
+ thread_id, keys_input, values_input, num_elements_to_sort);
465
+ }
466
+
467
+ // End bitonic sort
468
+
469
+ inline CUDA_CALLABLE int warp_scan_inclusive(int lane, unsigned int ballot_mask)
470
+ {
471
+ uint32_t mask = ((1u << (lane + 1)) - 1);
472
+ return __popc(ballot_mask & mask);
473
+ }
474
+
475
+ inline CUDA_CALLABLE int warp_scan_inclusive(int lane, unsigned int mask, bool thread_contributes_element)
476
+ {
477
+ return warp_scan_inclusive(lane, __ballot_sync(mask, thread_contributes_element));
478
+ }
479
+
480
+ template<typename T>
481
+ inline CUDA_CALLABLE T warp_scan_inclusive(int lane, T value)
482
+ {
483
+ //Computes an inclusive cumulative sum
484
+ #pragma unroll
485
+ for (int i = 1; i <= 32; i *= 2)
486
+ {
487
+ auto n = __shfl_up_sync(0xffffffffu, value, i, 32);
488
+
489
+ if (lane >= i)
490
+ value = value + n;
491
+ }
492
+ return value;
493
+ }
494
+
495
+ template<typename T>
496
+ inline CUDA_CALLABLE T warp_scan_exclusive(int lane, T value)
497
+ {
498
+ T scan = warp_scan_inclusive(lane, value);
499
+ return scan - value;
500
+ }
501
+
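The warp_scan_inclusive overloads above compute the same quantity at different granularities: the ballot-based versions count, for each lane, how many lanes at or below it contribute an element (an inclusive scan of 0/1 flags packed into a 32-bit ballot mask), while the shuffle-based version is a standard doubling-offset inclusive scan of arbitrary per-lane values, from which warp_scan_exclusive recovers the exclusive scan by subtracting the lane's own value. A host-side reference of what they compute; the function names below are illustrative, not part of this header:

    #include <bitset>
    #include <cassert>
    #include <cstdint>

    // Reference for warp_scan_inclusive(lane, ballot_mask): set bits at or below `lane`.
    static int ballot_scan_reference(uint32_t ballot_mask, int lane)
    {
        uint32_t below_or_equal = (lane >= 31) ? 0xFFFFFFFFu : ((1u << (lane + 1)) - 1u);
        return static_cast<int>(std::bitset<32>(ballot_mask & below_or_equal).count());
    }

    // Reference for the shuffle-based warp_scan_inclusive over one warp of values.
    static int value_scan_reference(const int values[32], int lane)
    {
        int sum = 0;
        for (int i = 0; i <= lane; ++i)
            sum += values[i];
        return sum;
    }

    int main()
    {
        assert(ballot_scan_reference(0xFFFFFFFFu, 0) == 1);    // only lane 0 counted
        assert(ballot_scan_reference(0xFFFFFFFFu, 31) == 32);  // all lanes counted

        int values[32];
        for (int i = 0; i < 32; ++i) values[i] = i;
        int inclusive = value_scan_reference(values, 3);       // 0 + 1 + 2 + 3
        assert(inclusive == 6);
        assert(inclusive - values[3] == 3);                    // exclusive scan at lane 3
        return 0;
    }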
502
+ template <int num_warps, int num_threads, typename K, typename V, typename KeyToUint>
503
+ inline CUDA_CALLABLE void radix_sort_thread_block_core(
504
+ int thread_id,
505
+ K* keys_input, K* keys_tmp,
506
+ V* values_input, V* values_tmp,
507
+ int num_elements_to_sort)
508
+ {
509
+ KeyToUint key_converter;
510
+
511
+ int num_bits_to_sort = 32; //Sort all bits because that's what the bitonic fast pass does as well
512
+
513
+ const int warp_id = thread_id / 32;
514
+ const int lane_id = thread_id & 31;
515
+
516
+ const int bits_per_pass = 4; //Values above 5 are not supported - 2^5=32 matches the warp size, which still works
517
+ const int lowest_bits_mask = (1 << bits_per_pass) - 1;
518
+ const int num_scan_buckets = (1 << bits_per_pass);
519
+
520
+ const int num_warp_passes = (num_scan_buckets + num_warps - 1) / num_warps;
521
+
522
+ __shared__ int buckets[num_scan_buckets];
523
+ __shared__ int buckets2[num_scan_buckets];
524
+ __shared__ int buckets_cumulative_sum[num_scan_buckets];
525
+ __shared__ int shared_mem[num_warps][num_scan_buckets];
526
+
527
+ const int num_passes = (num_bits_to_sort + bits_per_pass - 1) / bits_per_pass;
528
+ const int num_inner_loops = (num_elements_to_sort + num_threads - 1) / num_threads;
529
+
530
+ for (int pass_id = 0; pass_id < num_passes; ++pass_id)
531
+ {
532
+ __syncthreads();
533
+ if (thread_id < num_scan_buckets)
534
+ {
535
+ buckets[lane_id] = 0;
536
+ buckets2[lane_id] = 0;
537
+ }
538
+ __syncthreads();
539
+
540
+ int shift = pass_id * bits_per_pass;
541
+
542
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
543
+ {
544
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
545
+
546
+ for (int b = 0; b < num_scan_buckets; b++)
547
+ {
548
+ bool contributes = digit == b;
549
+ int sum_per_warp = warp_scan_inclusive(lane_id, 0xFFFFFFFF, contributes);
550
+
551
+ if (lane_id == 31)
552
+ shared_mem[warp_id][b] = sum_per_warp;
553
+ }
554
+ __syncthreads();
555
+
556
+ for(int b=warp_id;b< num_warp_passes * num_warps;b += num_warps)
557
+ {
558
+ int f = lane_id < num_warps ? shared_mem[lane_id][b] : 0;
559
+ f = warp_scan_inclusive(lane_id, f);
560
+ if (lane_id == 31)
561
+ buckets[b] += f;
562
+ }
563
+ __syncthreads();
564
+ }
565
+
566
+ #if VALIDATE_SORT
567
+ if (thread_id == 0)
568
+ {
569
+ for (int b = 0; b < num_scan_buckets; b++)
570
+ {
571
+ int bucket_sum = 0;
572
+ for (int j = 0; j < num_elements_to_sort; j++)
573
+ {
574
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
575
+ if (digit == b)
576
+ ++bucket_sum;
577
+ }
578
+ assert(buckets[b] == bucket_sum);
579
+ }
580
+ }
581
+ __syncthreads();
582
+ #endif
583
+
584
+ if (warp_id == 0)
585
+ {
586
+ int value = lane_id < num_scan_buckets ? buckets[lane_id] : 0;
587
+ value = warp_scan_exclusive(lane_id, value);
588
+ if (lane_id < num_scan_buckets)
589
+ buckets_cumulative_sum[lane_id] = value;
590
+
591
+ if (lane_id == num_scan_buckets - 1)
592
+ assert(value + buckets[lane_id] == num_elements_to_sort); // exclusive prefix of the last bucket plus its count equals the total
593
+ }
594
+
595
+ __syncthreads();
596
+
597
+ #if VALIDATE_SORT
598
+ if(thread_id == 0)
599
+ {
600
+ for (int b = 0; b < num_scan_buckets; b++)
601
+ {
602
+ int bucket_sum = 0;
603
+ for(int j=0; j<num_elements_to_sort; j++)
604
+ {
605
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
606
+ if (digit == b)
607
+ ++bucket_sum;
608
+ }
609
+ assert(buckets[b] == bucket_sum);
610
+ }
611
+
612
+ int exclusive_bucket_sum = 0;
613
+ for (int b = 0; b < num_scan_buckets; b++)
614
+ {
615
+ assert(exclusive_bucket_sum == buckets_cumulative_sum[b]);
616
+ exclusive_bucket_sum += buckets[b];
617
+ }
618
+ assert(exclusive_bucket_sum == num_elements_to_sort);
619
+ }
620
+ __syncthreads();
621
+ #endif
622
+
623
+ //Now buckets holds the per-bucket element counts (e.g. 16 buckets for a 4-bit radix pass - 2^4=16) and buckets_cumulative_sum their exclusive prefix sum
624
+ //The problem is that we must either store local_offset_per_thread for every element (potentially many) or recompute it here
625
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
626
+ {
627
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
628
+
629
+ int local_offset_per_thread = 0;
630
+
631
+ for (int b = 0; b < num_scan_buckets; b++)
632
+ {
633
+ bool contributes = digit == b;
634
+ int sum_per_warp = warp_scan_inclusive(lane_id, 0xFFFFFFFF, contributes);
635
+ if (lane_id == 31)
636
+ shared_mem[warp_id][b] = sum_per_warp;
637
+
638
+ if (contributes)
639
+ local_offset_per_thread = sum_per_warp - 1; //-1 because of inclusive scan and local_offset_per_thread needs exclusive scan
640
+ }
641
+
642
+ for (int b = 0; b < num_scan_buckets; b++)
643
+ {
644
+ __syncthreads();
645
+ int global_offset = buckets2[b];
646
+ __syncthreads();
647
+
648
+ int f = lane_id < num_warps ? shared_mem[lane_id][b] : 0;
649
+ int inclusive_scan = warp_scan_inclusive(lane_id, f);
650
+ if (lane_id == 31 && warp_id == 0)
651
+ {
652
+ buckets2[b] += inclusive_scan;
653
+ }
654
+
655
+ int warp_offset = __shfl_sync(0xFFFFFFFF, inclusive_scan - f, warp_id); //-f because warp_offset needs to be an exclusive scan
656
+
657
+ bool contributes = digit == b;
658
+ if (contributes)
659
+ {
660
+ local_offset_per_thread += global_offset + warp_offset;
661
+
662
+ #if VALIDATE_SORT
663
+ int curr = buckets_cumulative_sum[b];
664
+ int next = b + 1 < num_scan_buckets ? buckets_cumulative_sum[b + 1] : num_elements_to_sort;
665
+ assert(local_offset_per_thread < next - curr && local_offset_per_thread >= 0);
666
+ #endif
667
+ }
668
+ }
669
+ __syncthreads();
670
+
671
+ if (j < num_elements_to_sort)
672
+ {
673
+ int final_offset = buckets_cumulative_sum[digit] + local_offset_per_thread;
674
+
675
+ keys_tmp[final_offset] = keys_input[j];
676
+ values_tmp[final_offset] = values_input[j];
677
+ }
678
+ }
679
+
680
+ __syncthreads();
681
+
682
+ #if VALIDATE_SORT
683
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
684
+ {
685
+ if(j>0 && j < num_elements_to_sort)
686
+ {
687
+ int digit1 = (int)((keys_tmp[j-1] >> shift) & lowest_bits_mask);
688
+ int digit2 = (int)((keys_tmp[j] >> shift) & lowest_bits_mask);
689
+
690
+ assert(digit1<=digit2);
691
+ }
692
+ }
693
+ __syncthreads();
694
+ #endif
695
+
696
+ auto tmp = keys_tmp;
697
+ keys_tmp = keys_input;
698
+ keys_input = tmp;
699
+
700
+ auto tmp2 = values_tmp;
701
+ values_tmp = values_input;
702
+ values_input = tmp2;
703
+ }
704
+
705
+ //For an odd number of passes, the sorted result ends up in the temporary buffers - copy it back
706
+ if (num_passes % 2 != 0)
707
+ {
708
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
709
+ {
710
+ if (j < num_elements_to_sort)
711
+ {
712
+ keys_tmp[j] = keys_input[j];
713
+ values_tmp[j] = values_input[j];
714
+ }
715
+ }
716
+
717
+ auto tmp = keys_tmp;
718
+ keys_tmp = keys_input;
719
+ keys_input = tmp;
720
+
721
+ auto tmp2 = values_tmp;
722
+ values_tmp = values_input;
723
+ values_input = tmp2;
724
+ }
725
+ }
726
+
727
+
728
+
729
+
730
+ template <int num_warps, int num_threads, typename V>
731
+ inline CUDA_CALLABLE void radix_sort_thread_block(
732
+ int thread_id,
733
+ int* keys_input, int* keys_tmp,
734
+ V* values_input, V* values_tmp,
735
+ int num_elements_to_sort)
736
+ {
737
+ radix_sort_thread_block_core<num_warps, num_threads, int, V, IntKeyToUint>(
738
+ thread_id, keys_input, keys_tmp,
739
+ values_input, values_tmp, num_elements_to_sort);
740
+ }
741
+
742
+ template <int num_warps, int num_threads, typename V>
743
+ inline CUDA_CALLABLE void radix_sort_thread_block(
744
+ int thread_id,
745
+ unsigned int* keys_input, unsigned int* keys_tmp,
746
+ V* values_input, V* values_tmp,
747
+ int num_elements_to_sort)
748
+ {
749
+ radix_sort_thread_block_core<num_warps, num_threads, unsigned int, V, UintKeyToUint>(
750
+ thread_id, keys_input, keys_tmp,
751
+ values_input, values_tmp,
752
+ num_elements_to_sort);
753
+ }
754
+
755
+ template <int num_warps, int num_threads, typename V>
756
+ inline CUDA_CALLABLE void radix_sort_thread_block(
757
+ int thread_id,
758
+ float* keys_input, float* keys_tmp,
759
+ V* values_input, V* values_tmp,
760
+ int num_elements_to_sort)
761
+ {
762
+ radix_sort_thread_block_core<num_warps, num_threads, float, V, FloatKeyToUint>(
763
+ thread_id, keys_input, keys_tmp,
764
+ values_input, values_tmp,
765
+ num_elements_to_sort);
766
+ }
767
+
768
+
769
+ template <typename TileK, typename TileV>
770
+ void tile_sort(TileK& t, TileV& t2)
771
+ {
772
+ using T = typename TileK::Type;
773
+ using V = typename TileV::Type;
774
+
775
+ constexpr int num_elements_to_sort = TileK::Layout::Shape::size();
776
+ T* keys = &t.data(0);
777
+ V* values = &t2.data(0);
778
+
779
+ //Trim away the code that won't be used - possible because the number of elements to sort is known at compile time
780
+ if constexpr (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
781
+ {
782
+ if constexpr(is_power_of_two(num_elements_to_sort))
783
+ bitonic_sort_thread_block_direct<num_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
784
+ else
785
+ bitonic_sort_thread_block_shared_mem<num_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
786
+ }
787
+ else
788
+ {
789
+ __shared__ T keys_tmp[num_elements_to_sort];
790
+ __shared__ V values_tmp[num_elements_to_sort];
791
+
792
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
793
+
794
+ radix_sort_thread_block<warp_count, WP_TILE_BLOCK_DIM, V>(WP_TILE_THREAD_IDX, keys, keys_tmp,
795
+ values, values_tmp, num_elements_to_sort);
796
+ }
797
+
798
+ WP_TILE_SYNC();
799
+ }
800
+
801
+ template <typename TileK, typename TileV>
802
+ void tile_sort(TileK& t, TileV& t2, int start, int length)
803
+ {
804
+ using T = typename TileK::Type;
805
+ using V = typename TileV::Type;
806
+
807
+ constexpr int max_elements_to_sort = TileK::Layout::Shape::size();
808
+ const int num_elements_to_sort = length;
809
+ T* keys = &t.data(start);
810
+ V* values = &t2.data(start);
811
+
812
+ if (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
813
+ {
814
+ if (is_power_of_two(num_elements_to_sort))
815
+ bitonic_sort_thread_block_direct<max_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
816
+ else
817
+ bitonic_sort_thread_block_shared_mem<max_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
818
+ }
819
+ else
820
+ {
821
+ if constexpr (max_elements_to_sort > BITONIC_SORT_THRESHOLD)
822
+ {
823
+ __shared__ T keys_tmp[max_elements_to_sort];
824
+ __shared__ V values_tmp[max_elements_to_sort];
825
+
826
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
827
+
828
+ radix_sort_thread_block<warp_count, WP_TILE_BLOCK_DIM, V>(WP_TILE_THREAD_IDX, keys, keys_tmp,
829
+ values, values_tmp, num_elements_to_sort);
830
+ }
831
+ }
832
+
833
+ WP_TILE_SYNC();
834
+ }
835
+
836
+ #else
837
+
838
+ // CPU implementation
839
+
840
+ template <typename K>
841
+ void swap_elements(K& a, K& b)
842
+ {
843
+ K tmp = a;
844
+ a = b;
845
+ b = tmp;
846
+ }
847
+
848
+ // length must be a power of two
849
+ template <typename K, typename V>
850
+ void bitonic_sort_pairs_pow2_length_cpu(K* keys, V* values, int length)
851
+ {
852
+ for (int k = 2; k <= length; k *= 2)
853
+ {
854
+ for (int stride = k / 2; stride > 0; stride /= 2)
855
+ {
856
+ for (int i = 0; i < length; i++)
857
+ {
858
+ int swap_idx = i ^ stride;
859
+ if (swap_idx > i)
860
+ {
861
+ bool ascending = ((i & k) == 0);
862
+ if ((ascending && keys[i] > keys[swap_idx]) || (!ascending && keys[i] < keys[swap_idx]))
863
+ {
864
+ swap_elements(keys[i], keys[swap_idx]);
865
+ swap_elements(values[i], values[swap_idx]);
866
+ }
867
+ }
868
+ }
869
+ }
870
+ }
871
+ }
872
+
873
+ template <typename K, typename V, int max_size, typename KeyToUint>
874
+ void bitonic_sort_pairs_general_size_cpu(K* keys, V* values, int length)
875
+ {
876
+ constexpr int pow2_size = next_higher_pow2(max_size);
877
+
878
+ K keys_tmp[pow2_size];
879
+ V values_tmp[pow2_size];
880
+
881
+ KeyToUint converter;
882
+ K max_key = converter.max_possible_key_value();
883
+
884
+ for(int i=0; i<pow2_size; ++i)
885
+ {
886
+ keys_tmp[i] = i < length ? keys[i] : max_key;
887
+ if(i < length)
888
+ values_tmp[i] = values[i];
889
+ }
890
+
891
+ bitonic_sort_pairs_pow2_length_cpu(keys_tmp, values_tmp, pow2_size);
892
+
893
+ for(int i=0; i<length; ++i)
894
+ {
895
+ keys[i] = keys_tmp[i];
896
+ values[i] = values_tmp[i];
897
+ }
898
+ }
899
+
900
+ template <typename V, int max_size>
901
+ void bitonic_sort_pairs_general_size_cpu(unsigned int* keys, V* values, int length)
902
+ {
903
+ bitonic_sort_pairs_general_size_cpu<unsigned int, V, max_size, UintKeyToUint>(keys, values, length);
904
+ }
905
+
906
+ template <typename V, int max_size>
907
+ void bitonic_sort_pairs_general_size_cpu(int* keys, V* values, int length)
908
+ {
909
+ bitonic_sort_pairs_general_size_cpu<int, V, max_size, IntKeyToUint>(keys, values, length);
910
+ }
911
+
912
+ template <typename V, int max_size>
913
+ void bitonic_sort_pairs_general_size_cpu(float* keys, V* values, int length)
914
+ {
915
+ bitonic_sort_pairs_general_size_cpu<float, V, max_size, FloatKeyToUint>(keys, values, length);
916
+ }
917
+
918
+
919
+
920
+ template <typename K, typename V, typename KeyToUint>
921
+ void radix_sort_pairs_cpu_core(K* keys, K* aux_keys, V* values, V* aux_values, int n)
922
+ {
923
+ KeyToUint converter;
924
+ static unsigned int tables[2][1 << 16];
925
+ memset(tables, 0, sizeof(tables));
926
+
927
+ // build histograms
928
+ for (int i=0; i < n; ++i)
929
+ {
930
+ const unsigned int k = converter.convert(keys[i]);
931
+ const unsigned short low = k & 0xffff;
932
+ const unsigned short high = k >> 16;
933
+
934
+ ++tables[0][low];
935
+ ++tables[1][high];
936
+ }
937
+
938
+ // convert histograms to offset tables in-place
939
+ unsigned int offlow = 0;
940
+ unsigned int offhigh = 0;
941
+
942
+ for (int i=0; i < 65536; ++i)
943
+ {
944
+ const unsigned int newofflow = offlow + tables[0][i];
945
+ const unsigned int newoffhigh = offhigh + tables[1][i];
946
+
947
+ tables[0][i] = offlow;
948
+ tables[1][i] = offhigh;
949
+
950
+ offlow = newofflow;
951
+ offhigh = newoffhigh;
952
+ }
953
+
954
+ // pass 1 - sort by low 16 bits
955
+ for (int i=0; i < n; ++i)
956
+ {
957
+ // lookup offset of input
958
+ const K f = keys[i];
959
+ const unsigned int k = converter.convert(f);
960
+ const V v = values[i];
961
+ const unsigned int b = k & 0xffff;
962
+
963
+ // find offset and increment
964
+ const unsigned int offset = tables[0][b]++;
965
+
966
+ aux_keys[offset] = f;
967
+ aux_values[offset] = v;
968
+ }
969
+
970
+ // pass 2 - sort by high 16 bits
971
+ for (int i=0; i < n; ++i)
972
+ {
973
+ // lookup offset of input
974
+ const K f = aux_keys[i];
975
+ const unsigned int k = converter.convert(f);
976
+ const V v = aux_values[i];
977
+
978
+ const unsigned int b = k >> 16;
979
+
980
+ const unsigned int offset = tables[1][b]++;
981
+
982
+ keys[offset] = f;
983
+ values[offset] = v;
984
+ }
985
+ }
986
+
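radix_sort_pairs_cpu_core above is a two-pass least-significant-digit radix sort with 16-bit digits: a single pass over the keys builds both digit histograms, the histograms are converted in place to exclusive offsets, and two stable scatter passes (low half, then high half of the key) produce the fully sorted order. A standalone sketch of the same scheme for plain uint32_t keys with int values; the names below are illustrative and not the header's API:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Two-pass LSD radix sort with 16-bit digits, mirroring radix_sort_pairs_cpu_core.
    static void lsd_radix_sort_pairs(std::vector<uint32_t>& keys, std::vector<int>& values)
    {
        const size_t n = keys.size();
        std::vector<uint32_t> aux_keys(n);
        std::vector<int> aux_values(n);

        // Histograms of the low and high 16-bit digits, built in a single pass.
        std::vector<uint32_t> table_low(1 << 16, 0), table_high(1 << 16, 0);
        for (uint32_t k : keys) { ++table_low[k & 0xffff]; ++table_high[k >> 16]; }

        // Convert histograms to exclusive offsets in place.
        uint32_t off_low = 0, off_high = 0;
        for (int d = 0; d < (1 << 16); ++d)
        {
            uint32_t count_low = table_low[d], count_high = table_high[d];
            table_low[d] = off_low;   off_low += count_low;
            table_high[d] = off_high; off_high += count_high;
        }

        // Pass 1: stable scatter by the low 16 bits into the aux buffers.
        for (size_t i = 0; i < n; ++i)
        {
            uint32_t dst = table_low[keys[i] & 0xffff]++;
            aux_keys[dst] = keys[i];
            aux_values[dst] = values[i];
        }
        // Pass 2: stable scatter by the high 16 bits back into the original buffers.
        for (size_t i = 0; i < n; ++i)
        {
            uint32_t dst = table_high[aux_keys[i] >> 16]++;
            keys[dst] = aux_keys[i];
            values[dst] = aux_values[i];
        }
    }

    int main()
    {
        std::vector<uint32_t> keys   = { 70000, 3, 70000, 1, 65536, 2 };
        std::vector<int>      values = { 0,     1, 2,     3, 4,     5 };
        lsd_radix_sort_pairs(keys, values);
        for (size_t i = 1; i < keys.size(); ++i)
            assert(keys[i - 1] <= keys[i]);
        return 0;
    }

The int, unsigned int, and float entry points that follow reuse this scheme unchanged: they bucket on the converted key (IntKeyToUint, UintKeyToUint, FloatKeyToUint) while scattering the original keys and values.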
987
+ template <typename V>
988
+ inline void radix_sort_pairs_cpu(
989
+ int* keys_input,
990
+ int* keys_aux,
991
+ V* values_input,
992
+ V* values_aux,
993
+ int num_elements_to_sort)
994
+ {
995
+ radix_sort_pairs_cpu_core<int, V, IntKeyToUint>(
996
+ keys_input, keys_aux,
997
+ values_input, values_aux,
998
+ num_elements_to_sort);
999
+ }
1000
+
1001
+ template <typename V>
1002
+ inline void radix_sort_pairs_cpu(
1003
+ unsigned int* keys_input,
1004
+ unsigned int* keys_aux,
1005
+ V* values_input,
1006
+ V* values_aux,
1007
+ int num_elements_to_sort)
1008
+ {
1009
+ radix_sort_pairs_cpu_core<unsigned int, V, UintKeyToUint>(
1010
+ keys_input, keys_aux,
1011
+ values_input, values_aux,
1012
+ num_elements_to_sort);
1013
+ }
1014
+
1015
+ template <typename V>
1016
+ inline void radix_sort_pairs_cpu(
1017
+ float* keys_input,
1018
+ float* keys_aux,
1019
+ V* values_input,
1020
+ V* values_aux,
1021
+ int num_elements_to_sort)
1022
+ {
1023
+ radix_sort_pairs_cpu_core<float, V, FloatKeyToUint>(
1024
+ keys_input, keys_aux,
1025
+ values_input, values_aux,
1026
+ num_elements_to_sort);
1027
+ }
1028
+
1029
+
1030
+
1031
+ template <typename TileK, typename TileV>
1032
+ void tile_sort(TileK& t, TileV& t2)
1033
+ {
1034
+ using T = typename TileK::Type;
1035
+ using V = typename TileV::Type;
1036
+
1037
+ constexpr int num_elements_to_sort = TileK::Layout::Shape::size();
1038
+ T* keys = &t.data(0);
1039
+ V* values = &t2.data(0);
1040
+
1041
+ //Trim away the code that won't be used - possible because the number of elements to sort is known at compile time
1042
+ if constexpr (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
1043
+ {
1044
+ if constexpr(is_power_of_two(num_elements_to_sort))
1045
+ bitonic_sort_pairs_pow2_length_cpu<T, V>(keys, values, num_elements_to_sort);
1046
+ else
1047
+ bitonic_sort_pairs_general_size_cpu<V, num_elements_to_sort>(keys, values, num_elements_to_sort);
1048
+ }
1049
+ else
1050
+ {
1051
+ T keys_tmp[num_elements_to_sort];
1052
+ V values_tmp[num_elements_to_sort];
1053
+
1054
+ radix_sort_pairs_cpu<V>(keys, keys_tmp, values, values_tmp, num_elements_to_sort);
1055
+ }
1056
+
1057
+ WP_TILE_SYNC();
1058
+ }
1059
+
1060
+ template <typename TileK, typename TileV>
1061
+ void tile_sort(TileK& t, TileV& t2, int start, int length)
1062
+ {
1063
+ using T = typename TileK::Type;
1064
+ using V = typename TileV::Type;
1065
+
1066
+ constexpr int max_elements_to_sort = TileK::Layout::Shape::size();
1067
+ const int num_elements_to_sort = length;
1068
+ T* keys = &t.data(start);
1069
+ V* values = &t2.data(start);
1070
+
1071
+ if (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
1072
+ {
1073
+ if (is_power_of_two(num_elements_to_sort))
1074
+ bitonic_sort_pairs_pow2_length_cpu<T, V>(keys, values, num_elements_to_sort);
1075
+ else
1076
+ bitonic_sort_pairs_general_size_cpu<V, max_elements_to_sort>(keys, values, num_elements_to_sort);
1077
+ }
1078
+ else
1079
+ {
1080
+ if constexpr (max_elements_to_sort > BITONIC_SORT_THRESHOLD)
1081
+ {
1082
+ T keys_tmp[max_elements_to_sort];
1083
+ V values_tmp[max_elements_to_sort];
1084
+
1085
+ radix_sort_pairs_cpu<V>(keys, keys_tmp, values, values_tmp, num_elements_to_sort);
1086
+ }
1087
+ }
1088
+
1089
+ WP_TILE_SYNC();
1090
+ }
1091
+
1092
+
1093
+ #endif // !defined(__CUDA_ARCH__)
1094
+
1095
+
1096
+ template <typename TileK, typename TileV>
1097
+ inline void adj_tile_sort(TileK& t, TileV& t2, TileK& adj_t1, TileV& adj_t2)
1098
+ {
1099
+ // todo: general purpose sort gradients not implemented
1100
+ }
1101
+
1102
+ template <typename TileK, typename TileV>
1103
+ inline void adj_tile_sort(TileK& t, TileV& t2, int start, int length, TileK& adj_t1, TileV& adj_t2, int adj_start, int adj_length)
1104
+ {
1105
+ // todo: general purpose sort gradients not implemented
1106
+ }
1107
+
1108
+ } // namespace wp
1109
+
1110
+ #if defined(__clang__)
1111
+ #pragma clang diagnostic pop
1112
+ #endif