faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -24,18 +24,17 @@ namespace {
24
24
 
25
25
  typedef Index::idx_t idx_t;
26
26
 
27
-
28
27
  // add translation to all valid labels
29
- void translate_labels (long n, idx_t *labels, long translation)
30
- {
31
- if (translation == 0) return;
28
+ void translate_labels(long n, idx_t* labels, long translation) {
29
+ if (translation == 0)
30
+ return;
32
31
  for (long i = 0; i < n; i++) {
33
- if(labels[i] < 0) continue;
32
+ if (labels[i] < 0)
33
+ continue;
34
34
  labels[i] += translation;
35
35
  }
36
36
  }
37
37
 
38
-
39
38
  /** merge result tables from several shards.
40
39
  * @param all_distances size nshard * n * k
41
40
  * @param all_labels idem
@@ -43,296 +42,313 @@ void translate_labels (long n, idx_t *labels, long translation)
43
42
  */
44
43
 
45
44
  template <class IndexClass, class C>
46
- void
47
- merge_tables(long n, long k, long nshard,
48
- typename IndexClass::distance_t *distances,
49
- idx_t *labels,
50
- const std::vector<typename IndexClass::distance_t>& all_distances,
51
- const std::vector<idx_t>& all_labels,
52
- const std::vector<long>& translations) {
53
- if (k == 0) {
54
- return;
55
- }
56
- using distance_t = typename IndexClass::distance_t;
57
-
58
- long stride = n * k;
45
+ void merge_tables(
46
+ long n,
47
+ long k,
48
+ long nshard,
49
+ typename IndexClass::distance_t* distances,
50
+ idx_t* labels,
51
+ const std::vector<typename IndexClass::distance_t>& all_distances,
52
+ const std::vector<idx_t>& all_labels,
53
+ const std::vector<long>& translations) {
54
+ if (k == 0) {
55
+ return;
56
+ }
57
+ using distance_t = typename IndexClass::distance_t;
58
+
59
+ long stride = n * k;
59
60
  #pragma omp parallel
60
- {
61
- std::vector<int> buf (2 * nshard);
62
- int * pointer = buf.data();
63
- int * shard_ids = pointer + nshard;
64
- std::vector<distance_t> buf2 (nshard);
65
- distance_t * heap_vals = buf2.data();
61
+ {
62
+ std::vector<int> buf(2 * nshard);
63
+ int* pointer = buf.data();
64
+ int* shard_ids = pointer + nshard;
65
+ std::vector<distance_t> buf2(nshard);
66
+ distance_t* heap_vals = buf2.data();
66
67
  #pragma omp for
67
- for (long i = 0; i < n; i++) {
68
- // the heap maps values to the shard where they are
69
- // produced.
70
- const distance_t *D_in = all_distances.data() + i * k;
71
- const idx_t *I_in = all_labels.data() + i * k;
72
- int heap_size = 0;
73
-
74
- for (long s = 0; s < nshard; s++) {
75
- pointer[s] = 0;
76
- if (I_in[stride * s] >= 0) {
77
- heap_push<C> (++heap_size, heap_vals, shard_ids,
78
- D_in[stride * s], s);
79
- }
80
- }
81
-
82
- distance_t *D = distances + i * k;
83
- idx_t *I = labels + i * k;
84
-
85
- for (int j = 0; j < k; j++) {
86
- if (heap_size == 0) {
87
- I[j] = -1;
88
- D[j] = C::neutral();
89
- } else {
90
- // pop best element
91
- int s = shard_ids[0];
92
- int & p = pointer[s];
93
- D[j] = heap_vals[0];
94
- I[j] = I_in[stride * s + p] + translations[s];
95
-
96
- heap_pop<C> (heap_size--, heap_vals, shard_ids);
97
- p++;
98
- if (p < k && I_in[stride * s + p] >= 0) {
99
- heap_push<C> (++heap_size, heap_vals, shard_ids,
100
- D_in[stride * s + p], s);
101
- }
68
+ for (long i = 0; i < n; i++) {
69
+ // the heap maps values to the shard where they are
70
+ // produced.
71
+ const distance_t* D_in = all_distances.data() + i * k;
72
+ const idx_t* I_in = all_labels.data() + i * k;
73
+ int heap_size = 0;
74
+
75
+ for (long s = 0; s < nshard; s++) {
76
+ pointer[s] = 0;
77
+ if (I_in[stride * s] >= 0) {
78
+ heap_push<C>(
79
+ ++heap_size,
80
+ heap_vals,
81
+ shard_ids,
82
+ D_in[stride * s],
83
+ s);
84
+ }
85
+ }
86
+
87
+ distance_t* D = distances + i * k;
88
+ idx_t* I = labels + i * k;
89
+
90
+ for (int j = 0; j < k; j++) {
91
+ if (heap_size == 0) {
92
+ I[j] = -1;
93
+ D[j] = C::neutral();
94
+ } else {
95
+ // pop best element
96
+ int s = shard_ids[0];
97
+ int& p = pointer[s];
98
+ D[j] = heap_vals[0];
99
+ I[j] = I_in[stride * s + p] + translations[s];
100
+
101
+ heap_pop<C>(heap_size--, heap_vals, shard_ids);
102
+ p++;
103
+ if (p < k && I_in[stride * s + p] >= 0) {
104
+ heap_push<C>(
105
+ ++heap_size,
106
+ heap_vals,
107
+ shard_ids,
108
+ D_in[stride * s + p],
109
+ s);
110
+ }
111
+ }
112
+ }
102
113
  }
103
- }
104
114
  }
105
- }
106
115
  }
107
116
 
108
117
  } // anonymous namespace
109
118
 
110
119
  template <typename IndexT>
111
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(idx_t d,
112
- bool threaded,
113
- bool successive_ids)
114
- : ThreadedIndex<IndexT>(d, threaded),
115
- successive_ids(successive_ids) {
116
- }
120
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
121
+ idx_t d,
122
+ bool threaded,
123
+ bool successive_ids)
124
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
117
125
 
118
126
  template <typename IndexT>
119
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(int d,
120
- bool threaded,
121
- bool successive_ids)
122
- : ThreadedIndex<IndexT>(d, threaded),
123
- successive_ids(successive_ids) {
124
- }
127
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
128
+ int d,
129
+ bool threaded,
130
+ bool successive_ids)
131
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
125
132
 
126
133
  template <typename IndexT>
127
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(bool threaded,
128
- bool successive_ids)
129
- : ThreadedIndex<IndexT>(threaded),
130
- successive_ids(successive_ids) {
131
- }
134
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
135
+ bool threaded,
136
+ bool successive_ids)
137
+ : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
132
138
 
133
139
  template <typename IndexT>
134
- void
135
- IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
136
- syncWithSubIndexes();
140
+ void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
141
+ syncWithSubIndexes();
137
142
  }
138
143
 
139
144
  template <typename IndexT>
140
- void
141
- IndexShardsTemplate<IndexT>::onAfterRemoveIndex(IndexT* index /* unused */) {
142
- syncWithSubIndexes();
145
+ void IndexShardsTemplate<IndexT>::onAfterRemoveIndex(
146
+ IndexT* index /* unused */) {
147
+ syncWithSubIndexes();
143
148
  }
144
149
 
145
150
  // FIXME: assumes that nothing is currently running on the sub-indexes, which is
146
151
  // true with the normal API, but should use the runOnIndex API instead
147
152
  template <typename IndexT>
148
- void
149
- IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
150
- if (!this->count()) {
151
- this->is_trained = false;
152
- this->ntotal = 0;
153
-
154
- return;
155
- }
156
-
157
- auto firstIndex = this->at(0);
158
- this->metric_type = firstIndex->metric_type;
159
- this->is_trained = firstIndex->is_trained;
160
- this->ntotal = firstIndex->ntotal;
161
-
162
- for (int i = 1; i < this->count(); ++i) {
163
- auto index = this->at(i);
164
- FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
165
- FAISS_THROW_IF_NOT(this->d == index->d);
166
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
167
-
168
- this->ntotal += index->ntotal;
169
- }
153
+ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
154
+ if (!this->count()) {
155
+ this->is_trained = false;
156
+ this->ntotal = 0;
157
+
158
+ return;
159
+ }
160
+
161
+ auto firstIndex = this->at(0);
162
+ this->metric_type = firstIndex->metric_type;
163
+ this->is_trained = firstIndex->is_trained;
164
+ this->ntotal = firstIndex->ntotal;
165
+
166
+ for (int i = 1; i < this->count(); ++i) {
167
+ auto index = this->at(i);
168
+ FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
169
+ FAISS_THROW_IF_NOT(this->d == index->d);
170
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
171
+
172
+ this->ntotal += index->ntotal;
173
+ }
170
174
  }
171
175
 
172
176
  // No metric_type for IndexBinary
173
177
  template <>
174
- void
175
- IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
176
- if (!this->count()) {
177
- this->is_trained = false;
178
- this->ntotal = 0;
179
-
180
- return;
181
- }
182
-
183
- auto firstIndex = this->at(0);
184
- this->is_trained = firstIndex->is_trained;
185
- this->ntotal = firstIndex->ntotal;
186
-
187
- for (int i = 1; i < this->count(); ++i) {
188
- auto index = this->at(i);
189
- FAISS_THROW_IF_NOT(this->d == index->d);
190
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
191
-
192
- this->ntotal += index->ntotal;
193
- }
178
+ void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
179
+ if (!this->count()) {
180
+ this->is_trained = false;
181
+ this->ntotal = 0;
182
+
183
+ return;
184
+ }
185
+
186
+ auto firstIndex = this->at(0);
187
+ this->is_trained = firstIndex->is_trained;
188
+ this->ntotal = firstIndex->ntotal;
189
+
190
+ for (int i = 1; i < this->count(); ++i) {
191
+ auto index = this->at(i);
192
+ FAISS_THROW_IF_NOT(this->d == index->d);
193
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
194
+
195
+ this->ntotal += index->ntotal;
196
+ }
194
197
  }
195
198
 
196
199
  template <typename IndexT>
197
- void
198
- IndexShardsTemplate<IndexT>::train(idx_t n,
199
- const component_t *x) {
200
- auto fn =
201
- [n, x](int no, IndexT *index) {
202
- if (index->verbose) {
203
- printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
- }
205
-
206
- index->train(n, x);
207
-
208
- if (index->verbose) {
209
- printf("end train shard %d\n", no);
210
- }
200
+ void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
201
+ auto fn = [n, x](int no, IndexT* index) {
202
+ if (index->verbose) {
203
+ printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
+ }
205
+
206
+ index->train(n, x);
207
+
208
+ if (index->verbose) {
209
+ printf("end train shard %d\n", no);
210
+ }
211
211
  };
212
212
 
213
- this->runOnIndex(fn);
214
- syncWithSubIndexes();
213
+ this->runOnIndex(fn);
214
+ syncWithSubIndexes();
215
215
  }
216
216
 
217
217
  template <typename IndexT>
218
- void
219
- IndexShardsTemplate<IndexT>::add(idx_t n,
220
- const component_t *x) {
221
- add_with_ids(n, x, nullptr);
218
+ void IndexShardsTemplate<IndexT>::add(idx_t n, const component_t* x) {
219
+ add_with_ids(n, x, nullptr);
222
220
  }
223
221
 
224
222
  template <typename IndexT>
225
- void
226
- IndexShardsTemplate<IndexT>::add_with_ids(idx_t n,
227
- const component_t * x,
228
- const idx_t *xids) {
229
-
230
- FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
231
- "It makes no sense to pass in ids and "
232
- "request them to be shifted");
233
-
234
- if (successive_ids) {
235
- FAISS_THROW_IF_NOT_MSG(!xids,
236
- "It makes no sense to pass in ids and "
237
- "request them to be shifted");
238
- FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
239
- "when adding to IndexShards with sucessive_ids, "
240
- "only add() in a single pass is supported");
241
- }
242
-
243
- idx_t nshard = this->count();
244
- const idx_t *ids = xids;
245
-
246
- std::vector<idx_t> aids;
247
-
248
- if (!ids && !successive_ids) {
249
- aids.resize(n);
250
-
251
- for (idx_t i = 0; i < n; i++) {
252
- aids[i] = this->ntotal + i;
223
+ void IndexShardsTemplate<IndexT>::add_with_ids(
224
+ idx_t n,
225
+ const component_t* x,
226
+ const idx_t* xids) {
227
+ FAISS_THROW_IF_NOT_MSG(
228
+ !(successive_ids && xids),
229
+ "It makes no sense to pass in ids and "
230
+ "request them to be shifted");
231
+
232
+ if (successive_ids) {
233
+ FAISS_THROW_IF_NOT_MSG(
234
+ !xids,
235
+ "It makes no sense to pass in ids and "
236
+ "request them to be shifted");
237
+ FAISS_THROW_IF_NOT_MSG(
238
+ this->ntotal == 0,
239
+ "when adding to IndexShards with sucessive_ids, "
240
+ "only add() in a single pass is supported");
253
241
  }
254
242
 
255
- ids = aids.data();
256
- }
243
+ idx_t nshard = this->count();
244
+ const idx_t* ids = xids;
245
+
246
+ std::vector<idx_t> aids;
247
+
248
+ if (!ids && !successive_ids) {
249
+ aids.resize(n);
250
+
251
+ for (idx_t i = 0; i < n; i++) {
252
+ aids[i] = this->ntotal + i;
253
+ }
254
+
255
+ ids = aids.data();
256
+ }
257
257
 
258
- size_t components_per_vec =
259
- sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
258
+ size_t components_per_vec =
259
+ sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
260
260
 
261
- auto fn =
262
- [n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
263
- idx_t i0 = (idx_t) no * n / nshard;
264
- idx_t i1 = ((idx_t) no + 1) * n / nshard;
265
- auto x0 = x + i0 * components_per_vec;
261
+ auto fn = [n, ids, x, nshard, components_per_vec](int no, IndexT* index) {
262
+ idx_t i0 = (idx_t)no * n / nshard;
263
+ idx_t i1 = ((idx_t)no + 1) * n / nshard;
264
+ auto x0 = x + i0 * components_per_vec;
266
265
 
267
- if (index->verbose) {
268
- printf ("begin add shard %d on %" PRId64 " points\n", no, n);
269
- }
266
+ if (index->verbose) {
267
+ printf("begin add shard %d on %" PRId64 " points\n", no, n);
268
+ }
270
269
 
271
- if (ids) {
272
- index->add_with_ids (i1 - i0, x0, ids + i0);
273
- } else {
274
- index->add (i1 - i0, x0);
275
- }
270
+ if (ids) {
271
+ index->add_with_ids(i1 - i0, x0, ids + i0);
272
+ } else {
273
+ index->add(i1 - i0, x0);
274
+ }
276
275
 
277
- if (index->verbose) {
278
- printf ("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
279
- }
276
+ if (index->verbose) {
277
+ printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
278
+ }
280
279
  };
281
280
 
282
- this->runOnIndex(fn);
283
- syncWithSubIndexes();
281
+ this->runOnIndex(fn);
282
+ syncWithSubIndexes();
284
283
  }
285
284
 
286
285
  template <typename IndexT>
287
- void
288
- IndexShardsTemplate<IndexT>::search(idx_t n,
289
- const component_t *x,
290
- idx_t k,
291
- distance_t *distances,
292
- idx_t *labels) const {
293
- long nshard = this->count();
294
-
295
- std::vector<distance_t> all_distances(nshard * k * n);
296
- std::vector<idx_t> all_labels(nshard * k * n);
297
-
298
- auto fn =
299
- [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
300
- if (index->verbose) {
301
- printf ("begin query shard %d on %" PRId64 " points\n", no, n);
302
- }
303
-
304
- index->search (n, x, k,
305
- all_distances.data() + no * k * n,
306
- all_labels.data() + no * k * n);
307
-
308
- if (index->verbose) {
309
- printf ("end query shard %d\n", no);
310
- }
286
+ void IndexShardsTemplate<IndexT>::search(
287
+ idx_t n,
288
+ const component_t* x,
289
+ idx_t k,
290
+ distance_t* distances,
291
+ idx_t* labels) const {
292
+ FAISS_THROW_IF_NOT(k > 0);
293
+
294
+ long nshard = this->count();
295
+
296
+ std::vector<distance_t> all_distances(nshard * k * n);
297
+ std::vector<idx_t> all_labels(nshard * k * n);
298
+
299
+ auto fn = [n, k, x, &all_distances, &all_labels](
300
+ int no, const IndexT* index) {
301
+ if (index->verbose) {
302
+ printf("begin query shard %d on %" PRId64 " points\n", no, n);
303
+ }
304
+
305
+ index->search(
306
+ n,
307
+ x,
308
+ k,
309
+ all_distances.data() + no * k * n,
310
+ all_labels.data() + no * k * n);
311
+
312
+ if (index->verbose) {
313
+ printf("end query shard %d\n", no);
314
+ }
311
315
  };
312
316
 
313
- this->runOnIndex(fn);
317
+ this->runOnIndex(fn);
318
+
319
+ std::vector<long> translations(nshard, 0);
314
320
 
315
- std::vector<long> translations(nshard, 0);
321
+ // Because we just called runOnIndex above, it is safe to access the
322
+ // sub-index ntotal here
323
+ if (successive_ids) {
324
+ translations[0] = 0;
316
325
 
317
- // Because we just called runOnIndex above, it is safe to access the sub-index
318
- // ntotal here
319
- if (successive_ids) {
320
- translations[0] = 0;
326
+ for (int s = 0; s + 1 < nshard; s++) {
327
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
328
+ }
329
+ }
321
330
 
322
- for (int s = 0; s + 1 < nshard; s++) {
323
- translations[s + 1] = translations[s] + this->at(s)->ntotal;
331
+ if (this->metric_type == METRIC_L2) {
332
+ merge_tables<IndexT, CMin<distance_t, int>>(
333
+ n,
334
+ k,
335
+ nshard,
336
+ distances,
337
+ labels,
338
+ all_distances,
339
+ all_labels,
340
+ translations);
341
+ } else {
342
+ merge_tables<IndexT, CMax<distance_t, int>>(
343
+ n,
344
+ k,
345
+ nshard,
346
+ distances,
347
+ labels,
348
+ all_distances,
349
+ all_labels,
350
+ translations);
324
351
  }
325
- }
326
-
327
- if (this->metric_type == METRIC_L2) {
328
- merge_tables<IndexT, CMin<distance_t, int>>(
329
- n, k, nshard, distances, labels,
330
- all_distances, all_labels, translations);
331
- } else {
332
- merge_tables<IndexT, CMax<distance_t, int>>(
333
- n, k, nshard, distances, labels,
334
- all_distances, all_labels, translations);
335
- }
336
352
  }
337
353
 
338
354
  // explicit instanciations