faiss 0.2.0 → 0.2.4

Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp
@@ -6,542 +6,554 @@
  */
 
 #include <faiss/gpu/impl/InterleavedCodes.h>
-#include <faiss/impl/FaissAssert.h>
 #include <faiss/gpu/utils/StaticUtils.h>
+#include <faiss/impl/FaissAssert.h>
 
-namespace faiss { namespace gpu {
+namespace faiss {
+namespace gpu {
 
 inline uint8_t unpack5(int i, uint8_t vLower, uint8_t vUpper) {
-  uint8_t v = 0;
-
-  // lsb ... msb
-  // 0: 0 0 0 0 0 1 1 1
-  // 1: 1 1 2 2 2 2 2 3
-  // 2: 3 3 3 3 4 4 4 4
-  // 3: 4 5 5 5 5 5 6 6
-  // 4: 6 6 6 7 7 7 7 7
-  switch (i % 8) {
-    case 0:
-      // 5 lsbs of lower
-      v = vLower & 0x1f;
-      break;
-    case 1:
-      // 3 msbs of lower as v lsbs
-      // 2 msbs of upper as v msbs
-      v = (vLower >> 5) | ((vUpper & 0x3) << 3);
-      break;
-    case 2:
-      // 5 of lower
-      v = (vLower >> 2) & 0x1f;
-      break;
-    case 3:
-      // 1 msbs of lower as v lsbs
-      // 4 lsbs of upper as v msbs
-      v = (vLower >> 7) | ((vUpper & 0xf) << 1);
-      break;
-    case 4:
-      // 4 msbs of lower as v lsbs
-      // 1 lsbs of upper as v msbs
-      v = (vLower >> 4) | ((vUpper & 0x1) << 4);
-      break;
-    case 5:
-      // 5 of lower
-      v = (vLower >> 1) & 0x1f;
-      break;
-    case 6:
-      // 2 msbs of lower as v lsbs
-      // 3 lsbs of upper as v msbs
-      v = (vLower >> 6) | ((vUpper & 0x7) << 2);
-      break;
-    case 7:
-      // 5 of lower
-      v = (vLower >> 3);
-      break;
-  }
-
-  return v;
-}
+    uint8_t v = 0;
+
+    // lsb ... msb
+    // 0: 0 0 0 0 0 1 1 1
+    // 1: 1 1 2 2 2 2 2 3
+    // 2: 3 3 3 3 4 4 4 4
+    // 3: 4 5 5 5 5 5 6 6
+    // 4: 6 6 6 7 7 7 7 7
+    switch (i % 8) {
+        case 0:
+            // 5 lsbs of lower
+            v = vLower & 0x1f;
+            break;
+        case 1:
+            // 3 msbs of lower as v lsbs
+            // 2 msbs of upper as v msbs
+            v = (vLower >> 5) | ((vUpper & 0x3) << 3);
+            break;
+        case 2:
+            // 5 of lower
+            v = (vLower >> 2) & 0x1f;
+            break;
+        case 3:
+            // 1 msbs of lower as v lsbs
+            // 4 lsbs of upper as v msbs
+            v = (vLower >> 7) | ((vUpper & 0xf) << 1);
+            break;
+        case 4:
+            // 4 msbs of lower as v lsbs
+            // 1 lsbs of upper as v msbs
+            v = (vLower >> 4) | ((vUpper & 0x1) << 4);
+            break;
+        case 5:
+            // 5 of lower
+            v = (vLower >> 1) & 0x1f;
+            break;
+        case 6:
+            // 2 msbs of lower as v lsbs
+            // 3 lsbs of upper as v msbs
+            v = (vLower >> 6) | ((vUpper & 0x7) << 2);
+            break;
+        case 7:
+            // 5 of lower
+            v = (vLower >> 3);
+            break;
+    }
 
-inline uint8_t unpack6(int i, uint8_t vLower, uint8_t vUpper) {
-  uint8_t v = 0;
-
-  switch (i % 4) {
-    case 0:
-      // 6 lsbs of lower
-      v = vLower & 0x3f;
-      break;
-    case 1:
-      // 2 msbs of lower as v lsbs
-      // 4 lsbs of upper as v msbs
-      v = (vLower >> 6) | ((vUpper & 0xf) << 2);
-      break;
-    case 2:
-      // 4 msbs of lower as v lsbs
-      // 2 lsbs of upper as v msbs
-      v = (vLower >> 4) | ((vUpper & 0x3) << 4);
-      break;
-    case 3:
-      // 6 msbs of lower
-      v = (vLower >> 2);
-      break;
-  }
-
-  return v;
+    return v;
 }
 
+inline uint8_t unpack6(int i, uint8_t vLower, uint8_t vUpper) {
+    uint8_t v = 0;
+
+    switch (i % 4) {
+        case 0:
+            // 6 lsbs of lower
+            v = vLower & 0x3f;
+            break;
+        case 1:
+            // 2 msbs of lower as v lsbs
+            // 4 lsbs of upper as v msbs
+            v = (vLower >> 6) | ((vUpper & 0xf) << 2);
+            break;
+        case 2:
+            // 4 msbs of lower as v lsbs
+            // 2 lsbs of upper as v msbs
+            v = (vLower >> 4) | ((vUpper & 0x3) << 4);
+            break;
+        case 3:
+            // 6 msbs of lower
+            v = (vLower >> 2);
+            break;
+    }
 
-std::vector<uint8_t>
-unpackNonInterleaved(std::vector<uint8_t> data,
-                     int numVecs,
-                     int dims,
-                     int bitsPerCode) {
-  int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
-  FAISS_ASSERT(data.size() == numVecs * srcVecSize);
+    return v;
+}
 
-  if (bitsPerCode == 8 ||
-      bitsPerCode == 16 ||
-      bitsPerCode == 32) {
-    // nothing to do
-    return data;
-  }
+std::vector<uint8_t> unpackNonInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
+    FAISS_ASSERT(data.size() == numVecs * srcVecSize);
+
+    if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
+        // nothing to do
+        return data;
+    }
 
-  // bit codes padded to whole bytes
-  std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
+    // bit codes padded to whole bytes
+    std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
 
-  if (bitsPerCode == 4) {
+    if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        int srcIdx = i * srcVecSize + (j / 2);
-        FAISS_ASSERT(srcIdx < data.size());
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                int srcIdx = i * srcVecSize + (j / 2);
+                FAISS_ASSERT(srcIdx < data.size());
 
-        uint8_t v = data[srcIdx];
-        v = (j % 2 == 0) ? v & 0xf : v >> 4;
+                uint8_t v = data[srcIdx];
+                v = (j % 2 == 0) ? v & 0xf : v >> 4;
 
-        out[i * dims + j] = v;
-      }
-    }
-  } else if (bitsPerCode == 5) {
+                out[i * dims + j] = v;
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        int lo = i * srcVecSize + (j * 5) / 8;
-        int hi = lo + 1;
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                int lo = i * srcVecSize + (j * 5) / 8;
+                int hi = lo + 1;
 
-        FAISS_ASSERT(lo < data.size());
-        FAISS_ASSERT(hi <= data.size());
+                FAISS_ASSERT(lo < data.size());
+                FAISS_ASSERT(hi <= data.size());
 
-        auto vLower = data[lo];
-        auto vUpper = hi < data.size() ? data[hi] : 0;
+                auto vLower = data[lo];
+                auto vUpper = hi < data.size() ? data[hi] : 0;
 
-        out[i * dims + j] = unpack5(j, vLower, vUpper);
-      }
-    }
-  } else if (bitsPerCode == 6) {
+                out[i * dims + j] = unpack5(j, vLower, vUpper);
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        int lo = i * srcVecSize + (j * 6) / 8;
-        int hi = lo + 1;
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                int lo = i * srcVecSize + (j * 6) / 8;
+                int hi = lo + 1;
 
-        FAISS_ASSERT(lo < data.size());
-        FAISS_ASSERT(hi <= data.size());
+                FAISS_ASSERT(lo < data.size());
+                FAISS_ASSERT(hi <= data.size());
 
-        auto vLower = data[lo];
-        auto vUpper = hi < data.size() ? data[hi] : 0;
+                auto vLower = data[lo];
+                auto vUpper = hi < data.size() ? data[hi] : 0;
 
-        out[i * dims + j] = unpack6(j, vLower, vUpper);
-      }
+                out[i * dims + j] = unpack6(j, vLower, vUpper);
+            }
+        }
+    } else {
+        // unhandled
+        FAISS_ASSERT(false);
     }
-  } else {
-    // unhandled
-    FAISS_ASSERT(false);
-  }
 
-  return out;
+    return out;
 }
 
 template <typename T>
-void
-unpackInterleavedWord(const T* in,
-                      T* out,
-                      int numVecs,
-                      int dims,
-                      int bitsPerCode) {
-  int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
-  int wordsPerBlock = wordsPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
+void unpackInterleavedWord(
+        const T* in,
+        T* out,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
+    int wordsPerBlock = wordsPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
 
 #pragma omp parallel for
-  for (int i = 0; i < numVecs; ++i) {
-    int block = i / 32;
-    FAISS_ASSERT(block < numBlocks);
-    int lane = i % 32;
-
-    for (int j = 0; j < dims; ++j) {
-      int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
-      out[i * dims + j] = in[srcOffset];
+    for (int i = 0; i < numVecs; ++i) {
+        int block = i / 32;
+        FAISS_ASSERT(block < numBlocks);
+        int lane = i % 32;
+
+        for (int j = 0; j < dims; ++j) {
+            int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
+            out[i * dims + j] = in[srcOffset];
+        }
     }
-  }
 }
 
-std::vector<uint8_t>
-unpackInterleaved(std::vector<uint8_t> data,
-                  int numVecs,
-                  int dims,
-                  int bitsPerCode) {
-  int bytesPerDimBlock = 32 * bitsPerCode / 8;
-  int bytesPerBlock = bytesPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
-  size_t totalSize = (size_t) bytesPerBlock * numBlocks;
-  FAISS_ASSERT(data.size() == totalSize);
-
-  // bit codes padded to whole bytes
-  std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
-
-  if (bitsPerCode == 8) {
-    unpackInterleavedWord<uint8_t>(data.data(), out.data(),
-                                   numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 16) {
-    unpackInterleavedWord<uint16_t>((uint16_t*) data.data(),
-                                    (uint16_t*) out.data(),
-                                    numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 32) {
-    unpackInterleavedWord<uint32_t>((uint32_t*) data.data(),
-                                    (uint32_t*) out.data(),
-                                    numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 4) {
+std::vector<uint8_t> unpackInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int bytesPerDimBlock = 32 * bitsPerCode / 8;
+    int bytesPerBlock = bytesPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
+    size_t totalSize = (size_t)bytesPerBlock * numBlocks;
+    FAISS_ASSERT(data.size() == totalSize);
+
+    // bit codes padded to whole bytes
+    std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
+
+    if (bitsPerCode == 8) {
+        unpackInterleavedWord<uint8_t>(
+                data.data(), out.data(), numVecs, dims, bitsPerCode);
+    } else if (bitsPerCode == 16) {
+        unpackInterleavedWord<uint16_t>(
+                (uint16_t*)data.data(),
+                (uint16_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 32) {
+        unpackInterleavedWord<uint32_t>(
+                (uint32_t*)data.data(),
+                (uint32_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      int block = i / 32;
-      int lane = i % 32;
+        for (int i = 0; i < numVecs; ++i) {
+            int block = i / 32;
+            int lane = i % 32;
 
-      int word = lane / 2;
-      int subWord = lane % 2;
+            int word = lane / 2;
+            int subWord = lane % 2;
 
-      for (int j = 0; j < dims; ++j) {
-        auto v =
-          data[block * bytesPerBlock + j * bytesPerDimBlock + word];
+            for (int j = 0; j < dims; ++j) {
+                auto v =
+                        data[block * bytesPerBlock + j * bytesPerDimBlock +
+                             word];
 
-        v = (subWord == 0) ? v & 0xf : v >> 4;
-        out[i * dims + j] = v;
-      }
-    }
-  } else if (bitsPerCode == 5) {
+                v = (subWord == 0) ? v & 0xf : v >> 4;
+                out[i * dims + j] = v;
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      int block = i / 32;
-      int blockVector = i % 32;
+        for (int i = 0; i < numVecs; ++i) {
+            int block = i / 32;
+            int blockVector = i % 32;
 
-      for (int j = 0; j < dims; ++j) {
-        uint8_t* dimBlock =
-          &data[block * bytesPerBlock + j * bytesPerDimBlock];
+            for (int j = 0; j < dims; ++j) {
+                uint8_t* dimBlock =
+                        &data[block * bytesPerBlock + j * bytesPerDimBlock];
 
-        int lo = (blockVector * 5) / 8;
-        int hi = lo + 1;
+                int lo = (blockVector * 5) / 8;
+                int hi = lo + 1;
 
-        FAISS_ASSERT(lo < bytesPerDimBlock);
-        FAISS_ASSERT(hi <= bytesPerDimBlock);
+                FAISS_ASSERT(lo < bytesPerDimBlock);
+                FAISS_ASSERT(hi <= bytesPerDimBlock);
 
-        auto vLower = dimBlock[lo];
-        auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
+                auto vLower = dimBlock[lo];
+                auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
 
-        out[i * dims + j] = unpack5(blockVector, vLower, vUpper);
-      }
-    }
-  } else if (bitsPerCode == 6) {
+                out[i * dims + j] = unpack5(blockVector, vLower, vUpper);
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      int block = i / 32;
-      int blockVector = i % 32;
+        for (int i = 0; i < numVecs; ++i) {
+            int block = i / 32;
+            int blockVector = i % 32;
 
-      for (int j = 0; j < dims; ++j) {
-        uint8_t* dimBlock =
-          &data[block * bytesPerBlock + j * bytesPerDimBlock];
+            for (int j = 0; j < dims; ++j) {
+                uint8_t* dimBlock =
+                        &data[block * bytesPerBlock + j * bytesPerDimBlock];
 
-        int lo = (blockVector * 6) / 8;
-        int hi = lo + 1;
+                int lo = (blockVector * 6) / 8;
+                int hi = lo + 1;
 
-        FAISS_ASSERT(lo < bytesPerDimBlock);
-        FAISS_ASSERT(hi <= bytesPerDimBlock);
+                FAISS_ASSERT(lo < bytesPerDimBlock);
+                FAISS_ASSERT(hi <= bytesPerDimBlock);
 
-        auto vLower = dimBlock[lo];
-        auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
+                auto vLower = dimBlock[lo];
+                auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
 
-        out[i * dims + j] = unpack6(blockVector, vLower, vUpper);
-      }
+                out[i * dims + j] = unpack6(blockVector, vLower, vUpper);
+            }
+        }
+    } else {
+        // unimplemented
+        FAISS_ASSERT(false);
     }
-  } else {
-    // unimplemented
-    FAISS_ASSERT(false);
-  }
 
-  return out;
+    return out;
 }
 
 inline uint8_t pack5(int i, uint8_t lo, uint8_t hi, uint8_t hi2) {
-  FAISS_ASSERT((lo & 0x1f) == lo);
-  FAISS_ASSERT((hi & 0x1f) == hi);
-  FAISS_ASSERT((hi2 & 0x1f) == hi2);
-
-  uint8_t v = 0;
-
-  // lsb ... msb
-  // 0: 0 0 0 0 0 1 1 1
-  // 1: 1 1 2 2 2 2 2 3
-  // 2: 3 3 3 3 4 4 4 4
-  // 3: 4 5 5 5 5 5 6 6
-  // 4: 6 6 6 7 7 7 7 7
-  switch (i % 5) {
-    case 0:
-      // 5 msbs of lower as vOut lsbs
-      // 3 lsbs of upper as vOut msbs
-      v = (lo & 0x1f) | (hi << 5);
-      break;
-    case 1:
-      // 2 msbs of lower as vOut lsbs
-      // 5 lsbs of upper as vOut msbs
-      // 1 lsbs of upper2 as vOut msb
-      v = (lo >> 3) | (hi << 2) | (hi2 << 7);
-      break;
-    case 2:
-      // 4 msbs of lower as vOut lsbs
-      // 4 lsbs of upper as vOut msbs
-      v = (lo >> 1) | (hi << 4);
-      break;
-    case 3:
-      // 1 msbs of lower as vOut lsbs
-      // 5 lsbs of upper as vOut msbs
-      // 2 lsbs of upper2 as vOut msb
-      v = (lo >> 4) | (hi << 1) | (hi2 << 6);
-      break;
-    case 4:
-      // 3 msbs of lower as vOut lsbs
-      // 5 lsbs of upper as vOut msbs
-      v = (lo >> 2) | (hi << 3);
-      break;
-  }
-
-  return v;
-}
+    FAISS_ASSERT((lo & 0x1f) == lo);
+    FAISS_ASSERT((hi & 0x1f) == hi);
+    FAISS_ASSERT((hi2 & 0x1f) == hi2);
+
+    uint8_t v = 0;
+
+    // lsb ... msb
+    // 0: 0 0 0 0 0 1 1 1
+    // 1: 1 1 2 2 2 2 2 3
+    // 2: 3 3 3 3 4 4 4 4
+    // 3: 4 5 5 5 5 5 6 6
+    // 4: 6 6 6 7 7 7 7 7
+    switch (i % 5) {
+        case 0:
+            // 5 msbs of lower as vOut lsbs
+            // 3 lsbs of upper as vOut msbs
+            v = (lo & 0x1f) | (hi << 5);
+            break;
+        case 1:
+            // 2 msbs of lower as vOut lsbs
+            // 5 lsbs of upper as vOut msbs
+            // 1 lsbs of upper2 as vOut msb
+            v = (lo >> 3) | (hi << 2) | (hi2 << 7);
+            break;
+        case 2:
+            // 4 msbs of lower as vOut lsbs
+            // 4 lsbs of upper as vOut msbs
+            v = (lo >> 1) | (hi << 4);
+            break;
+        case 3:
+            // 1 msbs of lower as vOut lsbs
+            // 5 lsbs of upper as vOut msbs
+            // 2 lsbs of upper2 as vOut msb
+            v = (lo >> 4) | (hi << 1) | (hi2 << 6);
+            break;
+        case 4:
+            // 3 msbs of lower as vOut lsbs
+            // 5 lsbs of upper as vOut msbs
+            v = (lo >> 2) | (hi << 3);
+            break;
+    }
 
-inline uint8_t pack6(int i, uint8_t lo, uint8_t hi) {
-  FAISS_ASSERT((lo & 0x3f) == lo);
-  FAISS_ASSERT((hi & 0x3f) == hi);
-
-  uint8_t v = 0;
-
-  // lsb ... msb
-  // 0: 0 0 0 0 0 0 1 1
-  // 1: 1 1 1 1 2 2 2 2
-  // 2: 2 2 3 3 3 3 3 3
-  switch (i % 3) {
-    case 0:
-      // 6 msbs of lower as vOut lsbs
-      // 2 lsbs of upper as vOut msbs
-      v = (lo & 0x3f) | (hi << 6);
-      break;
-    case 1:
-      // 4 msbs of lower as vOut lsbs
-      // 4 lsbs of upper as vOut msbs
-      v = (lo >> 2) | (hi << 4);
-      break;
-    case 2:
-      // 2 msbs of lower as vOut lsbs
-      // 6 lsbs of upper as vOut msbs
-      v = (lo >> 4) | (hi << 2);
-      break;
-  }
-
-  return v;
+    return v;
 }
 
+inline uint8_t pack6(int i, uint8_t lo, uint8_t hi) {
+    FAISS_ASSERT((lo & 0x3f) == lo);
+    FAISS_ASSERT((hi & 0x3f) == hi);
+
+    uint8_t v = 0;
+
+    // lsb ... msb
+    // 0: 0 0 0 0 0 0 1 1
+    // 1: 1 1 1 1 2 2 2 2
+    // 2: 2 2 3 3 3 3 3 3
+    switch (i % 3) {
+        case 0:
+            // 6 msbs of lower as vOut lsbs
+            // 2 lsbs of upper as vOut msbs
+            v = (lo & 0x3f) | (hi << 6);
+            break;
+        case 1:
+            // 4 msbs of lower as vOut lsbs
+            // 4 lsbs of upper as vOut msbs
+            v = (lo >> 2) | (hi << 4);
+            break;
+        case 2:
+            // 2 msbs of lower as vOut lsbs
+            // 6 lsbs of upper as vOut msbs
+            v = (lo >> 4) | (hi << 2);
+            break;
+    }
 
-std::vector<uint8_t>
-packNonInterleaved(std::vector<uint8_t> data,
-                   int numVecs,
-                   int dims,
-                   int bitsPerCode) {
-  // bit codes padded to whole bytes
-  FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
+    return v;
+}
 
-  if (bitsPerCode == 8 ||
-      bitsPerCode == 16 ||
-      bitsPerCode == 32) {
-    // nothing to do, whole words are already where they need to be
-    return data;
-  }
+std::vector<uint8_t> packNonInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    // bit codes padded to whole bytes
+    FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
+
+    if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
+        // nothing to do, whole words are already where they need to be
+        return data;
+    }
 
-  // bits packed into a whole number of bytes
-  int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
+    // bits packed into a whole number of bytes
+    int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
 
-  std::vector<uint8_t> out(numVecs * bytesPerVec);
+    std::vector<uint8_t> out(numVecs * bytesPerVec);
 
-  if (bitsPerCode == 4) {
+    if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < bytesPerVec; ++j) {
-        int dimLo = j * 2;
-        int dimHi = dimLo + 1;
-        FAISS_ASSERT(dimLo < dims);
-        FAISS_ASSERT(dimHi <= dims);
-
-        uint8_t lo = data[i * dims + dimLo];
-        uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
-
-        out[i * bytesPerVec + j] = (hi << 4) | (lo & 0xf);
-      }
-    }
-  } else if (bitsPerCode == 5) {
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < bytesPerVec; ++j) {
+                int dimLo = j * 2;
+                int dimHi = dimLo + 1;
+                FAISS_ASSERT(dimLo < dims);
+                FAISS_ASSERT(dimHi <= dims);
+
+                uint8_t lo = data[i * dims + dimLo];
+                uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
+
+                out[i * bytesPerVec + j] = (hi << 4) | (lo & 0xf);
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < bytesPerVec; ++j) {
-        int dimLo = (j * 8) / 5;
-        int dimHi = dimLo + 1;
-        int dimHi2 = dimHi + 1;
-        FAISS_ASSERT(dimLo < dims);
-        FAISS_ASSERT(dimHi <= dims);
-        FAISS_ASSERT(dimHi <= dims + 1);
-
-        uint8_t lo = data[i * dims + dimLo];
-        uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
-        uint8_t hi2 = dimHi2 < dims ? data[i * dims + dimHi2] : 0;
-
-        out[i * bytesPerVec + j] = pack5(j, lo, hi, hi2);
-      }
-    }
-  } else if (bitsPerCode == 6) {
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < bytesPerVec; ++j) {
+                int dimLo = (j * 8) / 5;
+                int dimHi = dimLo + 1;
+                int dimHi2 = dimHi + 1;
+                FAISS_ASSERT(dimLo < dims);
+                FAISS_ASSERT(dimHi <= dims);
+                FAISS_ASSERT(dimHi <= dims + 1);
+
+                uint8_t lo = data[i * dims + dimLo];
+                uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
+                uint8_t hi2 = dimHi2 < dims ? data[i * dims + dimHi2] : 0;
+
+                out[i * bytesPerVec + j] = pack5(j, lo, hi, hi2);
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < bytesPerVec; ++j) {
-        int dimLo = (j * 8) / 6;
-        int dimHi = dimLo + 1;
-        FAISS_ASSERT(dimLo < dims);
-        FAISS_ASSERT(dimHi <= dims);
-
-        uint8_t lo = data[i * dims + dimLo];
-        uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
-
-        out[i * bytesPerVec + j] = pack6(j, lo, hi);
-      }
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < bytesPerVec; ++j) {
+                int dimLo = (j * 8) / 6;
+                int dimHi = dimLo + 1;
+                FAISS_ASSERT(dimLo < dims);
+                FAISS_ASSERT(dimHi <= dims);
+
+                uint8_t lo = data[i * dims + dimLo];
+                uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
+
+                out[i * bytesPerVec + j] = pack6(j, lo, hi);
+            }
+        }
+    } else {
+        // unhandled
+        FAISS_ASSERT(false);
     }
-  } else {
-    // unhandled
-    FAISS_ASSERT(false);
-  }
 
-  return out;
+    return out;
 }
 
 template <typename T>
-void
-packInterleavedWord(const T* in,
-                    T* out,
-                    int numVecs,
-                    int dims,
-                    int bitsPerCode) {
-  int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
-  int wordsPerBlock = wordsPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
-
-  // We're guaranteed that all other slots not filled by the vectors present are
-  // initialized to zero (from the vector constructor in packInterleaved)
+void packInterleavedWord(
+        const T* in,
+        T* out,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
+    int wordsPerBlock = wordsPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
+
+    // We're guaranteed that all other slots not filled by the vectors present
+    // are initialized to zero (from the vector constructor in packInterleaved)
 #pragma omp parallel for
-  for (int i = 0; i < numVecs; ++i) {
-    int block = i / 32;
-    FAISS_ASSERT(block < numBlocks);
-    int lane = i % 32;
-
-    for (int j = 0; j < dims; ++j) {
-      int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
-      out[dstOffset] = in[i * dims + j];
+    for (int i = 0; i < numVecs; ++i) {
+        int block = i / 32;
+        FAISS_ASSERT(block < numBlocks);
+        int lane = i % 32;
+
+        for (int j = 0; j < dims; ++j) {
+            int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
+            out[dstOffset] = in[i * dims + j];
+        }
     }
-  }
 }
 
-std::vector<uint8_t>
-packInterleaved(std::vector<uint8_t> data,
-                int numVecs,
-                int dims,
-                int bitsPerCode) {
-  int bytesPerDimBlock = 32 * bitsPerCode / 8;
-  int bytesPerBlock = bytesPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
-  size_t totalSize = (size_t) bytesPerBlock * numBlocks;
-
-  // bit codes padded to whole bytes
-  FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
-
-  // packs based on blocks
-  std::vector<uint8_t> out(totalSize, 0);
-
-  if (bitsPerCode == 8) {
-    packInterleavedWord<uint8_t>(data.data(), out.data(),
-                                 numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 16) {
-    packInterleavedWord<uint16_t>((uint16_t*) data.data(),
-                                  (uint16_t*) out.data(),
-                                  numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 32) {
-    packInterleavedWord<uint32_t>((uint32_t*) data.data(),
-                                  (uint32_t*) out.data(),
-                                  numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 4) {
+std::vector<uint8_t> packInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int bytesPerDimBlock = 32 * bitsPerCode / 8;
+    int bytesPerBlock = bytesPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
+    size_t totalSize = (size_t)bytesPerBlock * numBlocks;
+
+    // bit codes padded to whole bytes
+    FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
+
+    // packs based on blocks
+    std::vector<uint8_t> out(totalSize, 0);
+
+    if (bitsPerCode == 8) {
+        packInterleavedWord<uint8_t>(
+                data.data(), out.data(), numVecs, dims, bitsPerCode);
+    } else if (bitsPerCode == 16) {
+        packInterleavedWord<uint16_t>(
+                (uint16_t*)data.data(),
+                (uint16_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 32) {
+        packInterleavedWord<uint32_t>(
+                (uint32_t*)data.data(),
+                (uint32_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numBlocks; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        for (int k = 0; k < bytesPerDimBlock; ++k) {
-          int loVec = i * 32 + k * 2;
-          int hiVec = loVec + 1;
-
-          uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
-          uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
-
-          out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
-            (hi << 4) | (lo & 0xf);
+        for (int i = 0; i < numBlocks; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                for (int k = 0; k < bytesPerDimBlock; ++k) {
+                    int loVec = i * 32 + k * 2;
+                    int hiVec = loVec + 1;
+
+                    uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
+                    uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
+
+                    out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
+                            (hi << 4) | (lo & 0xf);
+                }
+            }
         }
-      }
-    }
-  } else if (bitsPerCode == 5) {
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numBlocks; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        for (int k = 0; k < bytesPerDimBlock; ++k) {
-          // What input vectors we are pulling from
-          int loVec = i * 32 + (k * 8) / 5;
-          int hiVec = loVec + 1;
-          int hiVec2 = hiVec + 1;
-
-          uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
-          uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
-          uint8_t hi2 = hiVec2 < numVecs ? data[hiVec2 * dims + j] : 0;
-
-          out[i * bytesPerBlock + j * bytesPerDimBlock + k] = pack5(k, lo, hi, hi2);
+        for (int i = 0; i < numBlocks; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                for (int k = 0; k < bytesPerDimBlock; ++k) {
+                    // What input vectors we are pulling from
+                    int loVec = i * 32 + (k * 8) / 5;
+                    int hiVec = loVec + 1;
+                    int hiVec2 = hiVec + 1;
+
+                    uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
+                    uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
+                    uint8_t hi2 =
+                            hiVec2 < numVecs ? data[hiVec2 * dims + j] : 0;
+
+                    out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
+                            pack5(k, lo, hi, hi2);
+                }
+            }
         }
-      }
-    }
-  } else if (bitsPerCode == 6) {
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numBlocks; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        for (int k = 0; k < bytesPerDimBlock; ++k) {
-          // What input vectors we are pulling from
-          int loVec = i * 32 + (k * 8) / 6;
-          int hiVec = loVec + 1;
-
-          uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
-          uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
-
-          out[i * bytesPerBlock + j * bytesPerDimBlock + k] = pack6(k, lo, hi);
+        for (int i = 0; i < numBlocks; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                for (int k = 0; k < bytesPerDimBlock; ++k) {
+                    // What input vectors we are pulling from
+                    int loVec = i * 32 + (k * 8) / 6;
+                    int hiVec = loVec + 1;
+
+                    uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
+                    uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
+
+                    out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
+                            pack6(k, lo, hi);
+                }
+            }
         }
-      }
+    } else {
+        // unimplemented
+        FAISS_ASSERT(false);
     }
-  } else {
-    // unimplemented
-    FAISS_ASSERT(false);
-  }
 
-  return out;
+    return out;
 }
 
-} } // namespace
+} // namespace gpu
+} // namespace faiss
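
A note for readers skimming the diff above: aside from the include reorder and the namespace split, every change in this file is mechanical clang-format reflow (2-space to 4-space indent, rewrapped argument lists); the pack/unpack logic is unchanged. The subtlest part of that logic is the 5-bit path, where each code straddles a byte boundary. The following is a minimal standalone sketch, not part of the gem: pack5/unpack5 mirror the bit layout from the diff, while the round-trip harness around them is our own addition for illustration.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Same layout as the diff's pack5: byte i (mod 5) of the stream.
uint8_t pack5(int i, uint8_t lo, uint8_t hi, uint8_t hi2) {
    switch (i % 5) {
        case 0: return (lo & 0x1f) | (hi << 5);
        case 1: return (lo >> 3) | (hi << 2) | (hi2 << 7);
        case 2: return (lo >> 1) | (hi << 4);
        case 3: return (lo >> 4) | (hi << 1) | (hi2 << 6);
        default: return (lo >> 2) | (hi << 3);
    }
}

// Same layout as the diff's unpack5: code i (mod 8) of the stream.
uint8_t unpack5(int i, uint8_t vLower, uint8_t vUpper) {
    switch (i % 8) {
        case 0: return vLower & 0x1f;
        case 1: return (vLower >> 5) | ((vUpper & 0x3) << 3);
        case 2: return (vLower >> 2) & 0x1f;
        case 3: return (vLower >> 7) | ((vUpper & 0xf) << 1);
        case 4: return (vLower >> 4) | ((vUpper & 0x1) << 4);
        case 5: return (vLower >> 1) & 0x1f;
        case 6: return (vLower >> 6) | ((vUpper & 0x7) << 2);
        default: return vLower >> 3;
    }
}

int main() {
    // Eight arbitrary 5-bit codes; 8 codes * 5 bits = 5 packed bytes.
    std::vector<uint8_t> codes = {3, 31, 0, 17, 9, 22, 5, 30};

    // Pack: byte j draws from codes starting at index (j * 8) / 5,
    // the same index arithmetic packNonInterleaved uses for 5-bit codes.
    std::vector<uint8_t> packed(5);
    for (int j = 0; j < 5; ++j) {
        int d = (j * 8) / 5;
        uint8_t lo = codes[d];
        uint8_t hi = d + 1 < 8 ? codes[d + 1] : 0;
        uint8_t hi2 = d + 2 < 8 ? codes[d + 2] : 0;
        packed[j] = pack5(j, lo, hi, hi2);
    }

    // Unpack: code j straddles byte (j * 5) / 8 and the byte after it.
    for (int j = 0; j < 8; ++j) {
        int lo = (j * 5) / 8;
        uint8_t vUpper = lo + 1 < 5 ? packed[lo + 1] : 0;
        assert(unpack5(j, packed[lo], vUpper) == codes[j]);
    }
    printf("5-bit round trip OK\n");
    return 0;
}

The interleaved variants apply the same byte layout per 32-vector block: vector i lands in block i / 32 at lane i % 32, and dimension j of that block starts at byte offset block * bytesPerBlock + j * bytesPerDimBlock, which is why unpack5/unpack6 are indexed by the lane rather than by the dimension there.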