faiss 0.2.0 → 0.2.1

Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
data/vendor/faiss/faiss/gpu/StandardGpuResources.h

@@ -5,138 +5,138 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-
 #pragma once
 
 #include <faiss/gpu/GpuResources.h>
-#include <faiss/gpu/utils/StackDeviceMemory.h>
 #include <faiss/gpu/utils/DeviceUtils.h>
+#include <faiss/gpu/utils/StackDeviceMemory.h>
 #include <functional>
 #include <map>
 #include <unordered_map>
 #include <vector>
 
-namespace faiss { namespace gpu {
+namespace faiss {
+namespace gpu {
 
 /// Standard implementation of the GpuResources object that provides for a
 /// temporary memory manager
 class StandardGpuResourcesImpl : public GpuResources {
- public:
-  StandardGpuResourcesImpl();
+   public:
+    StandardGpuResourcesImpl();
 
-  ~StandardGpuResourcesImpl() override;
+    ~StandardGpuResourcesImpl() override;
 
-  /// Disable allocation of temporary memory; all temporary memory
-  /// requests will call cudaMalloc / cudaFree at the point of use
-  void noTempMemory();
+    /// Disable allocation of temporary memory; all temporary memory
+    /// requests will call cudaMalloc / cudaFree at the point of use
+    void noTempMemory();
 
-  /// Specify that we wish to use a certain fixed size of memory on
-  /// all devices as temporary memory. This is the upper bound for the GPU
-  /// memory that we will reserve. We will never go above 1.5 GiB on any GPU;
-  /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that.
-  /// To avoid any temporary memory allocation, pass 0.
-  void setTempMemory(size_t size);
+    /// Specify that we wish to use a certain fixed size of memory on
+    /// all devices as temporary memory. This is the upper bound for the GPU
+    /// memory that we will reserve. We will never go above 1.5 GiB on any GPU;
+    /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that.
+    /// To avoid any temporary memory allocation, pass 0.
+    void setTempMemory(size_t size);
 
-  /// Set amount of pinned memory to allocate, for async GPU <-> CPU
-  /// transfers
-  void setPinnedMemory(size_t size);
+    /// Set amount of pinned memory to allocate, for async GPU <-> CPU
+    /// transfers
+    void setPinnedMemory(size_t size);
 
-  /// Called to change the stream for work ordering. We do not own `stream`;
-  /// i.e., it will not be destroyed when the GpuResources object gets cleaned
-  /// up.
-  /// We are guaranteed that all Faiss GPU work is ordered with respect to
-  /// this stream upon exit from an index or other Faiss GPU call.
-  void setDefaultStream(int device, cudaStream_t stream) override;
+    /// Called to change the stream for work ordering. We do not own `stream`;
+    /// i.e., it will not be destroyed when the GpuResources object gets cleaned
+    /// up.
+    /// We are guaranteed that all Faiss GPU work is ordered with respect to
+    /// this stream upon exit from an index or other Faiss GPU call.
+    void setDefaultStream(int device, cudaStream_t stream) override;
 
-  /// Revert the default stream to the original stream managed by this resources
-  /// object, in case someone called `setDefaultStream`.
-  void revertDefaultStream(int device);
+    /// Revert the default stream to the original stream managed by this
+    /// resources object, in case someone called `setDefaultStream`.
+    void revertDefaultStream(int device);
 
-  /// Returns the stream for the given device on which all Faiss GPU work is
-  /// ordered.
-  /// We are guaranteed that all Faiss GPU work is ordered with respect to
-  /// this stream upon exit from an index or other Faiss GPU call.
-  cudaStream_t getDefaultStream(int device) override;
+    /// Returns the stream for the given device on which all Faiss GPU work is
+    /// ordered.
+    /// We are guaranteed that all Faiss GPU work is ordered with respect to
+    /// this stream upon exit from an index or other Faiss GPU call.
+    cudaStream_t getDefaultStream(int device) override;
 
-  /// Called to change the work ordering streams to the null stream
-  /// for all devices
-  void setDefaultNullStreamAllDevices();
+    /// Called to change the work ordering streams to the null stream
+    /// for all devices
+    void setDefaultNullStreamAllDevices();
 
-  /// If enabled, will print every GPU memory allocation and deallocation to
-  /// standard output
-  void setLogMemoryAllocations(bool enable);
+    /// If enabled, will print every GPU memory allocation and deallocation to
+    /// standard output
+    void setLogMemoryAllocations(bool enable);
 
- public:
-  /// Internal system calls
+   public:
+    /// Internal system calls
 
-  /// Initialize resources for this device
-  void initializeForDevice(int device) override;
+    /// Initialize resources for this device
+    void initializeForDevice(int device) override;
 
-  cublasHandle_t getBlasHandle(int device) override;
+    cublasHandle_t getBlasHandle(int device) override;
 
-  std::vector<cudaStream_t> getAlternateStreams(int device) override;
+    std::vector<cudaStream_t> getAlternateStreams(int device) override;
 
-  /// Allocate non-temporary GPU memory
-  void* allocMemory(const AllocRequest& req) override;
+    /// Allocate non-temporary GPU memory
+    void* allocMemory(const AllocRequest& req) override;
 
-  /// Returns a previous allocation
-  void deallocMemory(int device, void* in) override;
+    /// Returns a previous allocation
+    void deallocMemory(int device, void* in) override;
 
-  size_t getTempMemoryAvailable(int device) const override;
+    size_t getTempMemoryAvailable(int device) const override;
 
-  /// Export a description of memory used for Python
-  std::map<int, std::map<std::string, std::pair<int, size_t>>>
-  getMemoryInfo() const;
+    /// Export a description of memory used for Python
+    std::map<int, std::map<std::string, std::pair<int, size_t>>> getMemoryInfo()
+            const;
 
-  std::pair<void*, size_t> getPinnedMemory() override;
+    std::pair<void*, size_t> getPinnedMemory() override;
 
-  cudaStream_t getAsyncCopyStream(int device) override;
+    cudaStream_t getAsyncCopyStream(int device) override;
 
- private:
-  /// Have GPU resources been initialized for this device yet?
-  bool isInitialized(int device) const;
+   private:
+    /// Have GPU resources been initialized for this device yet?
+    bool isInitialized(int device) const;
 
-  /// Adjust the default temporary memory allocation based on the total GPU
-  /// memory size
-  static size_t getDefaultTempMemForGPU(int device, size_t requested);
+    /// Adjust the default temporary memory allocation based on the total GPU
+    /// memory size
+    static size_t getDefaultTempMemForGPU(int device, size_t requested);
 
- private:
-  /// Set of currently outstanding memory allocations per device
-  /// device -> (alloc request, allocated ptr)
-  std::unordered_map<int, std::unordered_map<void*, AllocRequest>> allocs_;
+   private:
+    /// Set of currently outstanding memory allocations per device
+    /// device -> (alloc request, allocated ptr)
+    std::unordered_map<int, std::unordered_map<void*, AllocRequest>> allocs_;
 
-  /// Temporary memory provider, per each device
-  std::unordered_map<int, std::unique_ptr<StackDeviceMemory>> tempMemory_;
+    /// Temporary memory provider, per each device
+    std::unordered_map<int, std::unique_ptr<StackDeviceMemory>> tempMemory_;
 
-  /// Our default stream that work is ordered on, one per each device
-  std::unordered_map<int, cudaStream_t> defaultStreams_;
+    /// Our default stream that work is ordered on, one per each device
+    std::unordered_map<int, cudaStream_t> defaultStreams_;
 
-  /// This contains particular streams as set by the user for
-  /// ordering, if any
-  std::unordered_map<int, cudaStream_t> userDefaultStreams_;
+    /// This contains particular streams as set by the user for
+    /// ordering, if any
+    std::unordered_map<int, cudaStream_t> userDefaultStreams_;
 
-  /// Other streams we can use, per each device
-  std::unordered_map<int, std::vector<cudaStream_t>> alternateStreams_;
+    /// Other streams we can use, per each device
+    std::unordered_map<int, std::vector<cudaStream_t>> alternateStreams_;
 
-  /// Async copy stream to use for GPU <-> CPU pinned memory copies
-  std::unordered_map<int, cudaStream_t> asyncCopyStreams_;
+    /// Async copy stream to use for GPU <-> CPU pinned memory copies
+    std::unordered_map<int, cudaStream_t> asyncCopyStreams_;
 
-  /// cuBLAS handle for each device
-  std::unordered_map<int, cublasHandle_t> blasHandles_;
+    /// cuBLAS handle for each device
+    std::unordered_map<int, cublasHandle_t> blasHandles_;
 
-  /// Pinned memory allocation for use with this GPU
-  void* pinnedMemAlloc_;
-  size_t pinnedMemAllocSize_;
+    /// Pinned memory allocation for use with this GPU
+    void* pinnedMemAlloc_;
+    size_t pinnedMemAllocSize_;
 
-  /// Another option is to use a specified amount of memory on all
-  /// devices
-  size_t tempMemSize_;
+    /// Another option is to use a specified amount of memory on all
+    /// devices
+    size_t tempMemSize_;
 
-  /// Amount of pinned memory we should allocate
-  size_t pinnedMemSize_;
+    /// Amount of pinned memory we should allocate
+    size_t pinnedMemSize_;
 
-  /// Whether or not we log every GPU memory allocation and deallocation
-  bool allocLogging_;
+    /// Whether or not we log every GPU memory allocation and deallocation
+    bool allocLogging_;
 };
 
 /// Default implementation of GpuResources that allocates a cuBLAS
@@ -144,61 +144,62 @@ class StandardGpuResourcesImpl : public GpuResources {
 /// Internally, the Faiss GPU code uses the instance managed by getResources,
 /// but this is the user-facing object that is internally reference counted.
 class StandardGpuResources : public GpuResourcesProvider {
- public:
-  StandardGpuResources();
-  ~StandardGpuResources() override;
+   public:
+    StandardGpuResources();
+    ~StandardGpuResources() override;
 
-  std::shared_ptr<GpuResources> getResources() override;
+    std::shared_ptr<GpuResources> getResources() override;
 
-  /// Disable allocation of temporary memory; all temporary memory
-  /// requests will call cudaMalloc / cudaFree at the point of use
-  void noTempMemory();
+    /// Disable allocation of temporary memory; all temporary memory
+    /// requests will call cudaMalloc / cudaFree at the point of use
+    void noTempMemory();
 
-  /// Specify that we wish to use a certain fixed size of memory on
-  /// all devices as temporary memory. This is the upper bound for the GPU
-  /// memory that we will reserve. We will never go above 1.5 GiB on any GPU;
-  /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that.
-  /// To avoid any temporary memory allocation, pass 0.
-  void setTempMemory(size_t size);
+    /// Specify that we wish to use a certain fixed size of memory on
+    /// all devices as temporary memory. This is the upper bound for the GPU
+    /// memory that we will reserve. We will never go above 1.5 GiB on any GPU;
+    /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that.
+    /// To avoid any temporary memory allocation, pass 0.
+    void setTempMemory(size_t size);
 
-  /// Set amount of pinned memory to allocate, for async GPU <-> CPU
-  /// transfers
-  void setPinnedMemory(size_t size);
+    /// Set amount of pinned memory to allocate, for async GPU <-> CPU
+    /// transfers
+    void setPinnedMemory(size_t size);
 
-  /// Called to change the stream for work ordering. We do not own `stream`;
-  /// i.e., it will not be destroyed when the GpuResources object gets cleaned
-  /// up.
-  /// We are guaranteed that all Faiss GPU work is ordered with respect to
-  /// this stream upon exit from an index or other Faiss GPU call.
-  void setDefaultStream(int device, cudaStream_t stream);
+    /// Called to change the stream for work ordering. We do not own `stream`;
+    /// i.e., it will not be destroyed when the GpuResources object gets cleaned
+    /// up.
+    /// We are guaranteed that all Faiss GPU work is ordered with respect to
+    /// this stream upon exit from an index or other Faiss GPU call.
+    void setDefaultStream(int device, cudaStream_t stream);
 
-  /// Revert the default stream to the original stream managed by this resources
-  /// object, in case someone called `setDefaultStream`.
-  void revertDefaultStream(int device);
+    /// Revert the default stream to the original stream managed by this
+    /// resources object, in case someone called `setDefaultStream`.
+    void revertDefaultStream(int device);
 
-  /// Called to change the work ordering streams to the null stream
-  /// for all devices
-  void setDefaultNullStreamAllDevices();
+    /// Called to change the work ordering streams to the null stream
+    /// for all devices
+    void setDefaultNullStreamAllDevices();
 
-  /// Export a description of memory used for Python
-  std::map<int, std::map<std::string, std::pair<int, size_t>>>
-  getMemoryInfo() const;
+    /// Export a description of memory used for Python
+    std::map<int, std::map<std::string, std::pair<int, size_t>>> getMemoryInfo()
+            const;
 
-  /// Returns the current default stream
-  cudaStream_t getDefaultStream(int device);
+    /// Returns the current default stream
+    cudaStream_t getDefaultStream(int device);
 
-  /// Returns the current amount of temp memory available
-  size_t getTempMemoryAvailable(int device) const;
+    /// Returns the current amount of temp memory available
+    size_t getTempMemoryAvailable(int device) const;
 
-  /// Synchronize our default stream with the CPU
-  void syncDefaultStreamCurrentDevice();
+    /// Synchronize our default stream with the CPU
+    void syncDefaultStreamCurrentDevice();
 
-  /// If enabled, will print every GPU memory allocation and deallocation to
-  /// standard output
-  void setLogMemoryAllocations(bool enable);
+    /// If enabled, will print every GPU memory allocation and deallocation to
+    /// standard output
+    void setLogMemoryAllocations(bool enable);
 
- private:
-  std::shared_ptr<StandardGpuResourcesImpl> res_;
+   private:
+    std::shared_ptr<StandardGpuResourcesImpl> res_;
 };
 
-} } // namespace
+} // namespace gpu
+} // namespace faiss
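The StandardGpuResources.h diff above is purely cosmetic: the two-level `namespace faiss { namespace gpu {` opener becomes clang-format's one-namespace-per-line layout, doc comments are rewrapped, and the class bodies are reindented; no declarations change. For orientation, a minimal usage sketch of the API declared above — a sketch only, assuming a CUDA-enabled build of the vendored library; `GpuIndexFlatL2` and its `add`/`search` come from `faiss/gpu/GpuIndexFlat.h`, which is not part of this excerpt:

```c++
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/StandardGpuResources.h>

#include <vector>

int main() {
    int d = 64; // vector dimensionality

    // User-facing, reference-counted provider from the header above; the
    // StandardGpuResourcesImpl it wraps manages per-device streams, cuBLAS
    // handles, pinned memory, and the temporary-memory stack.
    faiss::gpu::StandardGpuResources res;
    res.setTempMemory(256 * 1024 * 1024); // cap scratch space at 256 MiB
    // res.noTempMemory();                // or: cudaMalloc/cudaFree on demand

    faiss::gpu::GpuIndexFlatL2 index(&res, d);

    std::vector<float> xb(10000 * d, 0.125f); // toy database
    index.add(10000, xb.data());

    int k = 5;
    std::vector<float> distances(k);
    std::vector<faiss::Index::idx_t> labels(k);
    index.search(1, xb.data(), k, distances.data(), labels.data());
    return 0;
}
```

Per the doc comments above, the value passed to `setTempMemory` is only an upper bound (never more than 1.5 GiB, less on small GPUs), and passing 0 or calling `noTempMemory()` makes every temporary request fall through to `cudaMalloc`/`cudaFree` at the point of use.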
data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp

@@ -6,542 +6,554 @@
  */
 
 #include <faiss/gpu/impl/InterleavedCodes.h>
-#include <faiss/impl/FaissAssert.h>
 #include <faiss/gpu/utils/StaticUtils.h>
+#include <faiss/impl/FaissAssert.h>
 
-namespace faiss { namespace gpu {
+namespace faiss {
+namespace gpu {
 
 inline uint8_t unpack5(int i, uint8_t vLower, uint8_t vUpper) {
-  uint8_t v = 0;
-
-  // lsb ... msb
-  // 0: 0 0 0 0 0 1 1 1
-  // 1: 1 1 2 2 2 2 2 3
-  // 2: 3 3 3 3 4 4 4 4
-  // 3: 4 5 5 5 5 5 6 6
-  // 4: 6 6 6 7 7 7 7 7
-  switch (i % 8) {
-    case 0:
-      // 5 lsbs of lower
-      v = vLower & 0x1f;
-      break;
-    case 1:
-      // 3 msbs of lower as v lsbs
-      // 2 msbs of upper as v msbs
-      v = (vLower >> 5) | ((vUpper & 0x3) << 3);
-      break;
-    case 2:
-      // 5 of lower
-      v = (vLower >> 2) & 0x1f;
-      break;
-    case 3:
-      // 1 msbs of lower as v lsbs
-      // 4 lsbs of upper as v msbs
-      v = (vLower >> 7) | ((vUpper & 0xf) << 1);
-      break;
-    case 4:
-      // 4 msbs of lower as v lsbs
-      // 1 lsbs of upper as v msbs
-      v = (vLower >> 4) | ((vUpper & 0x1) << 4);
-      break;
-    case 5:
-      // 5 of lower
-      v = (vLower >> 1) & 0x1f;
-      break;
-    case 6:
-      // 2 msbs of lower as v lsbs
-      // 3 lsbs of upper as v msbs
-      v = (vLower >> 6) | ((vUpper & 0x7) << 2);
-      break;
-    case 7:
-      // 5 of lower
-      v = (vLower >> 3);
-      break;
-  }
-
-  return v;
-}
+    uint8_t v = 0;
+
+    // lsb ... msb
+    // 0: 0 0 0 0 0 1 1 1
+    // 1: 1 1 2 2 2 2 2 3
+    // 2: 3 3 3 3 4 4 4 4
+    // 3: 4 5 5 5 5 5 6 6
+    // 4: 6 6 6 7 7 7 7 7
+    switch (i % 8) {
+        case 0:
+            // 5 lsbs of lower
+            v = vLower & 0x1f;
+            break;
+        case 1:
+            // 3 msbs of lower as v lsbs
+            // 2 msbs of upper as v msbs
+            v = (vLower >> 5) | ((vUpper & 0x3) << 3);
+            break;
+        case 2:
+            // 5 of lower
+            v = (vLower >> 2) & 0x1f;
+            break;
+        case 3:
+            // 1 msbs of lower as v lsbs
+            // 4 lsbs of upper as v msbs
+            v = (vLower >> 7) | ((vUpper & 0xf) << 1);
+            break;
+        case 4:
+            // 4 msbs of lower as v lsbs
+            // 1 lsbs of upper as v msbs
+            v = (vLower >> 4) | ((vUpper & 0x1) << 4);
+            break;
+        case 5:
+            // 5 of lower
+            v = (vLower >> 1) & 0x1f;
+            break;
+        case 6:
+            // 2 msbs of lower as v lsbs
+            // 3 lsbs of upper as v msbs
+            v = (vLower >> 6) | ((vUpper & 0x7) << 2);
+            break;
+        case 7:
+            // 5 of lower
+            v = (vLower >> 3);
+            break;
+    }
+
+    return v;
+}
 
-inline uint8_t unpack6(int i, uint8_t vLower, uint8_t vUpper) {
-  uint8_t v = 0;
-
-  switch (i % 4) {
-    case 0:
-      // 6 lsbs of lower
-      v = vLower & 0x3f;
-      break;
-    case 1:
-      // 2 msbs of lower as v lsbs
-      // 4 lsbs of upper as v msbs
-      v = (vLower >> 6) | ((vUpper & 0xf) << 2);
-      break;
-    case 2:
-      // 4 msbs of lower as v lsbs
-      // 2 lsbs of upper as v msbs
-      v = (vLower >> 4) | ((vUpper & 0x3) << 4);
-      break;
-    case 3:
-      // 6 msbs of lower
-      v = (vLower >> 2);
-      break;
-  }
-
-  return v;
-}
+inline uint8_t unpack6(int i, uint8_t vLower, uint8_t vUpper) {
+    uint8_t v = 0;
+
+    switch (i % 4) {
+        case 0:
+            // 6 lsbs of lower
+            v = vLower & 0x3f;
+            break;
+        case 1:
+            // 2 msbs of lower as v lsbs
+            // 4 lsbs of upper as v msbs
+            v = (vLower >> 6) | ((vUpper & 0xf) << 2);
+            break;
+        case 2:
+            // 4 msbs of lower as v lsbs
+            // 2 lsbs of upper as v msbs
+            v = (vLower >> 4) | ((vUpper & 0x3) << 4);
+            break;
+        case 3:
+            // 6 msbs of lower
+            v = (vLower >> 2);
+            break;
+    }
+
+    return v;
+}
 
-std::vector<uint8_t>
-unpackNonInterleaved(std::vector<uint8_t> data,
-                     int numVecs,
-                     int dims,
-                     int bitsPerCode) {
-  int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
-  FAISS_ASSERT(data.size() == numVecs * srcVecSize);
-
-  if (bitsPerCode == 8 ||
-      bitsPerCode == 16 ||
-      bitsPerCode == 32) {
-    // nothing to do
-    return data;
-  }
+std::vector<uint8_t> unpackNonInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
+    FAISS_ASSERT(data.size() == numVecs * srcVecSize);
+
+    if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
+        // nothing to do
+        return data;
+    }
 
-  // bit codes padded to whole bytes
-  std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
+    // bit codes padded to whole bytes
+    std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
 
-  if (bitsPerCode == 4) {
+    if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        int srcIdx = i * srcVecSize + (j / 2);
-        FAISS_ASSERT(srcIdx < data.size());
-
-        uint8_t v = data[srcIdx];
-        v = (j % 2 == 0) ? v & 0xf : v >> 4;
-
-        out[i * dims + j] = v;
-      }
-    }
-  } else if (bitsPerCode == 5) {
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                int srcIdx = i * srcVecSize + (j / 2);
+                FAISS_ASSERT(srcIdx < data.size());
+
+                uint8_t v = data[srcIdx];
+                v = (j % 2 == 0) ? v & 0xf : v >> 4;
+
+                out[i * dims + j] = v;
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        int lo = i * srcVecSize + (j * 5) / 8;
-        int hi = lo + 1;
-
-        FAISS_ASSERT(lo < data.size());
-        FAISS_ASSERT(hi <= data.size());
-
-        auto vLower = data[lo];
-        auto vUpper = hi < data.size() ? data[hi] : 0;
-
-        out[i * dims + j] = unpack5(j, vLower, vUpper);
-      }
-    }
-  } else if (bitsPerCode == 6) {
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                int lo = i * srcVecSize + (j * 5) / 8;
+                int hi = lo + 1;
+
+                FAISS_ASSERT(lo < data.size());
+                FAISS_ASSERT(hi <= data.size());
+
+                auto vLower = data[lo];
+                auto vUpper = hi < data.size() ? data[hi] : 0;
+
+                out[i * dims + j] = unpack5(j, vLower, vUpper);
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        int lo = i * srcVecSize + (j * 6) / 8;
-        int hi = lo + 1;
-
-        FAISS_ASSERT(lo < data.size());
-        FAISS_ASSERT(hi <= data.size());
-
-        auto vLower = data[lo];
-        auto vUpper = hi < data.size() ? data[hi] : 0;
-
-        out[i * dims + j] = unpack6(j, vLower, vUpper);
-      }
-    }
-  } else {
-    // unhandled
-    FAISS_ASSERT(false);
-  }
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                int lo = i * srcVecSize + (j * 6) / 8;
+                int hi = lo + 1;
+
+                FAISS_ASSERT(lo < data.size());
+                FAISS_ASSERT(hi <= data.size());
+
+                auto vLower = data[lo];
+                auto vUpper = hi < data.size() ? data[hi] : 0;
+
+                out[i * dims + j] = unpack6(j, vLower, vUpper);
+            }
+        }
+    } else {
+        // unhandled
+        FAISS_ASSERT(false);
+    }
 
-  return out;
+    return out;
 }
 
 template <typename T>
-void
-unpackInterleavedWord(const T* in,
-                      T* out,
-                      int numVecs,
-                      int dims,
-                      int bitsPerCode) {
-  int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
-  int wordsPerBlock = wordsPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
+void unpackInterleavedWord(
+        const T* in,
+        T* out,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
+    int wordsPerBlock = wordsPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
 
 #pragma omp parallel for
-  for (int i = 0; i < numVecs; ++i) {
-    int block = i / 32;
-    FAISS_ASSERT(block < numBlocks);
-    int lane = i % 32;
-
-    for (int j = 0; j < dims; ++j) {
-      int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
-      out[i * dims + j] = in[srcOffset];
-    }
-  }
+    for (int i = 0; i < numVecs; ++i) {
+        int block = i / 32;
+        FAISS_ASSERT(block < numBlocks);
+        int lane = i % 32;
+
+        for (int j = 0; j < dims; ++j) {
+            int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
+            out[i * dims + j] = in[srcOffset];
+        }
+    }
 }
 
-std::vector<uint8_t>
-unpackInterleaved(std::vector<uint8_t> data,
-                  int numVecs,
-                  int dims,
-                  int bitsPerCode) {
-  int bytesPerDimBlock = 32 * bitsPerCode / 8;
-  int bytesPerBlock = bytesPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
-  size_t totalSize = (size_t) bytesPerBlock * numBlocks;
-  FAISS_ASSERT(data.size() == totalSize);
-
-  // bit codes padded to whole bytes
-  std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
-
-  if (bitsPerCode == 8) {
-    unpackInterleavedWord<uint8_t>(data.data(), out.data(),
-                                   numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 16) {
-    unpackInterleavedWord<uint16_t>((uint16_t*) data.data(),
-                                    (uint16_t*) out.data(),
-                                    numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 32) {
-    unpackInterleavedWord<uint32_t>((uint32_t*) data.data(),
-                                    (uint32_t*) out.data(),
-                                    numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 4) {
+std::vector<uint8_t> unpackInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int bytesPerDimBlock = 32 * bitsPerCode / 8;
+    int bytesPerBlock = bytesPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
+    size_t totalSize = (size_t)bytesPerBlock * numBlocks;
+    FAISS_ASSERT(data.size() == totalSize);
+
+    // bit codes padded to whole bytes
+    std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
+
+    if (bitsPerCode == 8) {
+        unpackInterleavedWord<uint8_t>(
+                data.data(), out.data(), numVecs, dims, bitsPerCode);
+    } else if (bitsPerCode == 16) {
+        unpackInterleavedWord<uint16_t>(
+                (uint16_t*)data.data(),
+                (uint16_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 32) {
+        unpackInterleavedWord<uint32_t>(
+                (uint32_t*)data.data(),
+                (uint32_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      int block = i / 32;
-      int lane = i % 32;
-
-      int word = lane / 2;
-      int subWord = lane % 2;
-
-      for (int j = 0; j < dims; ++j) {
-        auto v =
-          data[block * bytesPerBlock + j * bytesPerDimBlock + word];
-
-        v = (subWord == 0) ? v & 0xf : v >> 4;
-        out[i * dims + j] = v;
-      }
-    }
-  } else if (bitsPerCode == 5) {
+        for (int i = 0; i < numVecs; ++i) {
+            int block = i / 32;
+            int lane = i % 32;
+
+            int word = lane / 2;
+            int subWord = lane % 2;
+
+            for (int j = 0; j < dims; ++j) {
+                auto v =
+                        data[block * bytesPerBlock + j * bytesPerDimBlock +
+                             word];
+
+                v = (subWord == 0) ? v & 0xf : v >> 4;
+                out[i * dims + j] = v;
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      int block = i / 32;
-      int blockVector = i % 32;
-
-      for (int j = 0; j < dims; ++j) {
-        uint8_t* dimBlock =
-          &data[block * bytesPerBlock + j * bytesPerDimBlock];
-
-        int lo = (blockVector * 5) / 8;
-        int hi = lo + 1;
-
-        FAISS_ASSERT(lo < bytesPerDimBlock);
-        FAISS_ASSERT(hi <= bytesPerDimBlock);
-
-        auto vLower = dimBlock[lo];
-        auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
-
-        out[i * dims + j] = unpack5(blockVector, vLower, vUpper);
-      }
-    }
-  } else if (bitsPerCode == 6) {
+        for (int i = 0; i < numVecs; ++i) {
+            int block = i / 32;
+            int blockVector = i % 32;
+
+            for (int j = 0; j < dims; ++j) {
+                uint8_t* dimBlock =
+                        &data[block * bytesPerBlock + j * bytesPerDimBlock];
+
+                int lo = (blockVector * 5) / 8;
+                int hi = lo + 1;
+
+                FAISS_ASSERT(lo < bytesPerDimBlock);
+                FAISS_ASSERT(hi <= bytesPerDimBlock);
+
+                auto vLower = dimBlock[lo];
+                auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
+
+                out[i * dims + j] = unpack5(blockVector, vLower, vUpper);
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      int block = i / 32;
-      int blockVector = i % 32;
-
-      for (int j = 0; j < dims; ++j) {
-        uint8_t* dimBlock =
-          &data[block * bytesPerBlock + j * bytesPerDimBlock];
-
-        int lo = (blockVector * 6) / 8;
-        int hi = lo + 1;
-
-        FAISS_ASSERT(lo < bytesPerDimBlock);
-        FAISS_ASSERT(hi <= bytesPerDimBlock);
-
-        auto vLower = dimBlock[lo];
-        auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
-
-        out[i * dims + j] = unpack6(blockVector, vLower, vUpper);
-      }
-    }
-  } else {
-    // unimplemented
-    FAISS_ASSERT(false);
-  }
+        for (int i = 0; i < numVecs; ++i) {
+            int block = i / 32;
+            int blockVector = i % 32;
+
+            for (int j = 0; j < dims; ++j) {
+                uint8_t* dimBlock =
+                        &data[block * bytesPerBlock + j * bytesPerDimBlock];
+
+                int lo = (blockVector * 6) / 8;
+                int hi = lo + 1;
+
+                FAISS_ASSERT(lo < bytesPerDimBlock);
+                FAISS_ASSERT(hi <= bytesPerDimBlock);
+
+                auto vLower = dimBlock[lo];
+                auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
+
+                out[i * dims + j] = unpack6(blockVector, vLower, vUpper);
+            }
+        }
+    } else {
+        // unimplemented
+        FAISS_ASSERT(false);
+    }
 
-  return out;
+    return out;
 }
 
 inline uint8_t pack5(int i, uint8_t lo, uint8_t hi, uint8_t hi2) {
-  FAISS_ASSERT((lo & 0x1f) == lo);
-  FAISS_ASSERT((hi & 0x1f) == hi);
-  FAISS_ASSERT((hi2 & 0x1f) == hi2);
-
-  uint8_t v = 0;
-
-  // lsb ... msb
-  // 0: 0 0 0 0 0 1 1 1
-  // 1: 1 1 2 2 2 2 2 3
-  // 2: 3 3 3 3 4 4 4 4
-  // 3: 4 5 5 5 5 5 6 6
-  // 4: 6 6 6 7 7 7 7 7
-  switch (i % 5) {
-    case 0:
-      // 5 msbs of lower as vOut lsbs
-      // 3 lsbs of upper as vOut msbs
-      v = (lo & 0x1f) | (hi << 5);
-      break;
-    case 1:
-      // 2 msbs of lower as vOut lsbs
-      // 5 lsbs of upper as vOut msbs
-      // 1 lsbs of upper2 as vOut msb
-      v = (lo >> 3) | (hi << 2) | (hi2 << 7);
-      break;
-    case 2:
-      // 4 msbs of lower as vOut lsbs
-      // 4 lsbs of upper as vOut msbs
-      v = (lo >> 1) | (hi << 4);
-      break;
-    case 3:
-      // 1 msbs of lower as vOut lsbs
-      // 5 lsbs of upper as vOut msbs
-      // 2 lsbs of upper2 as vOut msb
-      v = (lo >> 4) | (hi << 1) | (hi2 << 6);
-      break;
-    case 4:
-      // 3 msbs of lower as vOut lsbs
-      // 5 lsbs of upper as vOut msbs
-      v = (lo >> 2) | (hi << 3);
-      break;
-  }
-
-  return v;
-}
+    FAISS_ASSERT((lo & 0x1f) == lo);
+    FAISS_ASSERT((hi & 0x1f) == hi);
+    FAISS_ASSERT((hi2 & 0x1f) == hi2);
+
+    uint8_t v = 0;
+
+    // lsb ... msb
+    // 0: 0 0 0 0 0 1 1 1
+    // 1: 1 1 2 2 2 2 2 3
+    // 2: 3 3 3 3 4 4 4 4
+    // 3: 4 5 5 5 5 5 6 6
+    // 4: 6 6 6 7 7 7 7 7
+    switch (i % 5) {
+        case 0:
+            // 5 msbs of lower as vOut lsbs
+            // 3 lsbs of upper as vOut msbs
+            v = (lo & 0x1f) | (hi << 5);
+            break;
+        case 1:
+            // 2 msbs of lower as vOut lsbs
+            // 5 lsbs of upper as vOut msbs
+            // 1 lsbs of upper2 as vOut msb
+            v = (lo >> 3) | (hi << 2) | (hi2 << 7);
+            break;
+        case 2:
+            // 4 msbs of lower as vOut lsbs
+            // 4 lsbs of upper as vOut msbs
+            v = (lo >> 1) | (hi << 4);
+            break;
+        case 3:
+            // 1 msbs of lower as vOut lsbs
+            // 5 lsbs of upper as vOut msbs
+            // 2 lsbs of upper2 as vOut msb
+            v = (lo >> 4) | (hi << 1) | (hi2 << 6);
+            break;
+        case 4:
+            // 3 msbs of lower as vOut lsbs
+            // 5 lsbs of upper as vOut msbs
+            v = (lo >> 2) | (hi << 3);
+            break;
+    }
+
+    return v;
+}
 
-inline uint8_t pack6(int i, uint8_t lo, uint8_t hi) {
-  FAISS_ASSERT((lo & 0x3f) == lo);
-  FAISS_ASSERT((hi & 0x3f) == hi);
-
-  uint8_t v = 0;
-
-  // lsb ... msb
-  // 0: 0 0 0 0 0 0 1 1
-  // 1: 1 1 1 1 2 2 2 2
-  // 2: 2 2 3 3 3 3 3 3
-  switch (i % 3) {
-    case 0:
-      // 6 msbs of lower as vOut lsbs
-      // 2 lsbs of upper as vOut msbs
-      v = (lo & 0x3f) | (hi << 6);
-      break;
-    case 1:
-      // 4 msbs of lower as vOut lsbs
-      // 4 lsbs of upper as vOut msbs
-      v = (lo >> 2) | (hi << 4);
-      break;
-    case 2:
-      // 2 msbs of lower as vOut lsbs
-      // 6 lsbs of upper as vOut msbs
-      v = (lo >> 4) | (hi << 2);
-      break;
-  }
-
-  return v;
-}
+inline uint8_t pack6(int i, uint8_t lo, uint8_t hi) {
+    FAISS_ASSERT((lo & 0x3f) == lo);
+    FAISS_ASSERT((hi & 0x3f) == hi);
+
+    uint8_t v = 0;
+
+    // lsb ... msb
+    // 0: 0 0 0 0 0 0 1 1
+    // 1: 1 1 1 1 2 2 2 2
+    // 2: 2 2 3 3 3 3 3 3
+    switch (i % 3) {
+        case 0:
+            // 6 msbs of lower as vOut lsbs
+            // 2 lsbs of upper as vOut msbs
+            v = (lo & 0x3f) | (hi << 6);
+            break;
+        case 1:
+            // 4 msbs of lower as vOut lsbs
+            // 4 lsbs of upper as vOut msbs
+            v = (lo >> 2) | (hi << 4);
+            break;
+        case 2:
+            // 2 msbs of lower as vOut lsbs
+            // 6 lsbs of upper as vOut msbs
+            v = (lo >> 4) | (hi << 2);
+            break;
+    }
+
+    return v;
+}
 
-std::vector<uint8_t>
-packNonInterleaved(std::vector<uint8_t> data,
-                   int numVecs,
-                   int dims,
-                   int bitsPerCode) {
-  // bit codes padded to whole bytes
-  FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
-
-  if (bitsPerCode == 8 ||
-      bitsPerCode == 16 ||
-      bitsPerCode == 32) {
-    // nothing to do, whole words are already where they need to be
-    return data;
-  }
+std::vector<uint8_t> packNonInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    // bit codes padded to whole bytes
+    FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
+
+    if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
+        // nothing to do, whole words are already where they need to be
+        return data;
+    }
 
-  // bits packed into a whole number of bytes
-  int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
+    // bits packed into a whole number of bytes
+    int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
 
-  std::vector<uint8_t> out(numVecs * bytesPerVec);
+    std::vector<uint8_t> out(numVecs * bytesPerVec);
 
-  if (bitsPerCode == 4) {
+    if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < bytesPerVec; ++j) {
-        int dimLo = j * 2;
-        int dimHi = dimLo + 1;
-        FAISS_ASSERT(dimLo < dims);
-        FAISS_ASSERT(dimHi <= dims);
-
-        uint8_t lo = data[i * dims + dimLo];
-        uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
-
-        out[i * bytesPerVec + j] = (hi << 4) | (lo & 0xf);
-      }
-    }
-  } else if (bitsPerCode == 5) {
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < bytesPerVec; ++j) {
+                int dimLo = j * 2;
+                int dimHi = dimLo + 1;
+                FAISS_ASSERT(dimLo < dims);
+                FAISS_ASSERT(dimHi <= dims);
+
+                uint8_t lo = data[i * dims + dimLo];
+                uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
+
+                out[i * bytesPerVec + j] = (hi << 4) | (lo & 0xf);
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < bytesPerVec; ++j) {
-        int dimLo = (j * 8) / 5;
-        int dimHi = dimLo + 1;
-        int dimHi2 = dimHi + 1;
-        FAISS_ASSERT(dimLo < dims);
-        FAISS_ASSERT(dimHi <= dims);
-        FAISS_ASSERT(dimHi <= dims + 1);
-
-        uint8_t lo = data[i * dims + dimLo];
-        uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
-        uint8_t hi2 = dimHi2 < dims ? data[i * dims + dimHi2] : 0;
-
-        out[i * bytesPerVec + j] = pack5(j, lo, hi, hi2);
-      }
-    }
-  } else if (bitsPerCode == 6) {
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < bytesPerVec; ++j) {
+                int dimLo = (j * 8) / 5;
+                int dimHi = dimLo + 1;
+                int dimHi2 = dimHi + 1;
+                FAISS_ASSERT(dimLo < dims);
+                FAISS_ASSERT(dimHi <= dims);
+                FAISS_ASSERT(dimHi <= dims + 1);
+
+                uint8_t lo = data[i * dims + dimLo];
+                uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
+                uint8_t hi2 = dimHi2 < dims ? data[i * dims + dimHi2] : 0;
+
+                out[i * bytesPerVec + j] = pack5(j, lo, hi, hi2);
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numVecs; ++i) {
-      for (int j = 0; j < bytesPerVec; ++j) {
-        int dimLo = (j * 8) / 6;
-        int dimHi = dimLo + 1;
-        FAISS_ASSERT(dimLo < dims);
-        FAISS_ASSERT(dimHi <= dims);
-
-        uint8_t lo = data[i * dims + dimLo];
-        uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
-
-        out[i * bytesPerVec + j] = pack6(j, lo, hi);
-      }
-    }
-  } else {
-    // unhandled
-    FAISS_ASSERT(false);
-  }
+        for (int i = 0; i < numVecs; ++i) {
+            for (int j = 0; j < bytesPerVec; ++j) {
+                int dimLo = (j * 8) / 6;
+                int dimHi = dimLo + 1;
+                FAISS_ASSERT(dimLo < dims);
+                FAISS_ASSERT(dimHi <= dims);
+
+                uint8_t lo = data[i * dims + dimLo];
+                uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
+
+                out[i * bytesPerVec + j] = pack6(j, lo, hi);
+            }
+        }
+    } else {
+        // unhandled
+        FAISS_ASSERT(false);
+    }
 
-  return out;
+    return out;
 }
 
 template <typename T>
-void
-packInterleavedWord(const T* in,
-                    T* out,
-                    int numVecs,
-                    int dims,
-                    int bitsPerCode) {
-  int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
-  int wordsPerBlock = wordsPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
-
-  // We're guaranteed that all other slots not filled by the vectors present are
-  // initialized to zero (from the vector constructor in packInterleaved)
+void packInterleavedWord(
+        const T* in,
+        T* out,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
+    int wordsPerBlock = wordsPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
+
+    // We're guaranteed that all other slots not filled by the vectors present
+    // are initialized to zero (from the vector constructor in packInterleaved)
 #pragma omp parallel for
-  for (int i = 0; i < numVecs; ++i) {
-    int block = i / 32;
-    FAISS_ASSERT(block < numBlocks);
-    int lane = i % 32;
-
-    for (int j = 0; j < dims; ++j) {
-      int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
-      out[dstOffset] = in[i * dims + j];
-    }
-  }
+    for (int i = 0; i < numVecs; ++i) {
+        int block = i / 32;
+        FAISS_ASSERT(block < numBlocks);
+        int lane = i % 32;
+
+        for (int j = 0; j < dims; ++j) {
+            int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
+            out[dstOffset] = in[i * dims + j];
+        }
+    }
 }
 
-std::vector<uint8_t>
-packInterleaved(std::vector<uint8_t> data,
-                int numVecs,
-                int dims,
-                int bitsPerCode) {
-  int bytesPerDimBlock = 32 * bitsPerCode / 8;
-  int bytesPerBlock = bytesPerDimBlock * dims;
-  int numBlocks = utils::divUp(numVecs, 32);
-  size_t totalSize = (size_t) bytesPerBlock * numBlocks;
-
-  // bit codes padded to whole bytes
-  FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
-
-  // packs based on blocks
-  std::vector<uint8_t> out(totalSize, 0);
-
-  if (bitsPerCode == 8) {
-    packInterleavedWord<uint8_t>(data.data(), out.data(),
-                                 numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 16) {
-    packInterleavedWord<uint16_t>((uint16_t*) data.data(),
-                                  (uint16_t*) out.data(),
-                                  numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 32) {
-    packInterleavedWord<uint32_t>((uint32_t*) data.data(),
-                                  (uint32_t*) out.data(),
-                                  numVecs, dims, bitsPerCode);
-  } else if (bitsPerCode == 4) {
+std::vector<uint8_t> packInterleaved(
+        std::vector<uint8_t> data,
+        int numVecs,
+        int dims,
+        int bitsPerCode) {
+    int bytesPerDimBlock = 32 * bitsPerCode / 8;
+    int bytesPerBlock = bytesPerDimBlock * dims;
+    int numBlocks = utils::divUp(numVecs, 32);
+    size_t totalSize = (size_t)bytesPerBlock * numBlocks;
+
+    // bit codes padded to whole bytes
+    FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
+
+    // packs based on blocks
+    std::vector<uint8_t> out(totalSize, 0);
+
+    if (bitsPerCode == 8) {
+        packInterleavedWord<uint8_t>(
+                data.data(), out.data(), numVecs, dims, bitsPerCode);
+    } else if (bitsPerCode == 16) {
+        packInterleavedWord<uint16_t>(
+                (uint16_t*)data.data(),
+                (uint16_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 32) {
+        packInterleavedWord<uint32_t>(
+                (uint32_t*)data.data(),
+                (uint32_t*)out.data(),
+                numVecs,
+                dims,
+                bitsPerCode);
+    } else if (bitsPerCode == 4) {
 #pragma omp parallel for
-    for (int i = 0; i < numBlocks; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        for (int k = 0; k < bytesPerDimBlock; ++k) {
-          int loVec = i * 32 + k * 2;
-          int hiVec = loVec + 1;
-
-          uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
-          uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
-
-          out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
-            (hi << 4) | (lo & 0xf);
-        }
-      }
-    }
-  } else if (bitsPerCode == 5) {
+        for (int i = 0; i < numBlocks; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                for (int k = 0; k < bytesPerDimBlock; ++k) {
+                    int loVec = i * 32 + k * 2;
+                    int hiVec = loVec + 1;
+
+                    uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
+                    uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
+
+                    out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
+                            (hi << 4) | (lo & 0xf);
+                }
+            }
+        }
+    } else if (bitsPerCode == 5) {
 #pragma omp parallel for
-    for (int i = 0; i < numBlocks; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        for (int k = 0; k < bytesPerDimBlock; ++k) {
-          // What input vectors we are pulling from
-          int loVec = i * 32 + (k * 8) / 5;
-          int hiVec = loVec + 1;
-          int hiVec2 = hiVec + 1;
-
-          uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
-          uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
-          uint8_t hi2 = hiVec2 < numVecs ? data[hiVec2 * dims + j] : 0;
-
-          out[i * bytesPerBlock + j * bytesPerDimBlock + k] = pack5(k, lo, hi, hi2);
-        }
-      }
-    }
-  } else if (bitsPerCode == 6) {
+        for (int i = 0; i < numBlocks; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                for (int k = 0; k < bytesPerDimBlock; ++k) {
+                    // What input vectors we are pulling from
+                    int loVec = i * 32 + (k * 8) / 5;
+                    int hiVec = loVec + 1;
+                    int hiVec2 = hiVec + 1;
+
+                    uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
+                    uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
+                    uint8_t hi2 =
+                            hiVec2 < numVecs ? data[hiVec2 * dims + j] : 0;
+
+                    out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
+                            pack5(k, lo, hi, hi2);
+                }
+            }
+        }
+    } else if (bitsPerCode == 6) {
 #pragma omp parallel for
-    for (int i = 0; i < numBlocks; ++i) {
-      for (int j = 0; j < dims; ++j) {
-        for (int k = 0; k < bytesPerDimBlock; ++k) {
-          // What input vectors we are pulling from
-          int loVec = i * 32 + (k * 8) / 6;
-          int hiVec = loVec + 1;
-
-          uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
-          uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
-
-          out[i * bytesPerBlock + j * bytesPerDimBlock + k] = pack6(k, lo, hi);
-        }
-      }
-    }
-  } else {
-    // unimplemented
-    FAISS_ASSERT(false);
-  }
+        for (int i = 0; i < numBlocks; ++i) {
+            for (int j = 0; j < dims; ++j) {
+                for (int k = 0; k < bytesPerDimBlock; ++k) {
+                    // What input vectors we are pulling from
+                    int loVec = i * 32 + (k * 8) / 6;
+                    int hiVec = loVec + 1;
+
+                    uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
+                    uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
+
+                    out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
+                            pack6(k, lo, hi);
+                }
+            }
+        }
+    } else {
+        // unimplemented
+        FAISS_ASSERT(false);
+    }
 
-  return out;
+    return out;
 }
 
-} } // namespace
+} // namespace gpu
+} // namespace faiss
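The InterleavedCodes.cpp diff is likewise a whitespace-only clang-format pass over the bit-packing helpers. To make the 6-bit layout concrete, here is a small self-contained round-trip check — a sketch that restates `pack6`/`unpack6` and the index arithmetic of `packNonInterleaved`/`unpackNonInterleaved` from the diff in plain C++, with no faiss dependency:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors pack6 above: packed byte i mixes the tail of one 6-bit code
// with the head of the next (three bytes carry four codes).
uint8_t pack6(int i, uint8_t lo, uint8_t hi) {
    switch (i % 3) {
        case 0: return (lo & 0x3f) | (hi << 6);
        case 1: return (lo >> 2) | (hi << 4);
        default: return (lo >> 4) | (hi << 2);
    }
}

// Mirrors unpack6 above (four codes per three bytes).
uint8_t unpack6(int i, uint8_t vLower, uint8_t vUpper) {
    switch (i % 4) {
        case 0: return vLower & 0x3f;
        case 1: return (vLower >> 6) | ((vUpper & 0xf) << 2);
        case 2: return (vLower >> 4) | ((vUpper & 0x3) << 4);
        default: return vLower >> 2;
    }
}

int main() {
    // Eight 6-bit codes pack into 8 * 6 / 8 = 6 bytes.
    std::vector<uint8_t> codes = {1, 63, 17, 42, 0, 9, 33, 58};
    std::vector<uint8_t> packed(codes.size() * 6 / 8);

    for (size_t j = 0; j < packed.size(); ++j) {
        size_t dimLo = (j * 8) / 6; // code whose low bits start in byte j
        uint8_t lo = codes[dimLo];
        uint8_t hi = dimLo + 1 < codes.size() ? codes[dimLo + 1] : 0;
        packed[j] = pack6((int)j, lo, hi);
    }

    for (size_t j = 0; j < codes.size(); ++j) {
        size_t lo = (j * 6) / 8; // byte holding code j's low bits
        uint8_t vLower = packed[lo];
        uint8_t vUpper = lo + 1 < packed.size() ? packed[lo + 1] : 0;
        assert(unpack6((int)j, vLower, vUpper) == codes[j]);
    }
    return 0;
}
```

The 4-bit and 5-bit paths follow the same pattern, and the interleaved variants apply it within 32-vector blocks, matching the `block`/`lane` indexing in the functions above.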