faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -9,332 +9,340 @@
9
9
 
10
10
  #include <faiss/IndexIVFFlat.h>
11
11
 
12
+ #include <omp.h>
13
+
12
14
  #include <cinttypes>
13
15
  #include <cstdio>
14
16
 
15
17
  #include <faiss/IndexFlat.h>
16
18
 
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
17
21
  #include <faiss/utils/distances.h>
18
22
  #include <faiss/utils/utils.h>
19
- #include <faiss/impl/FaissAssert.h>
20
- #include <faiss/impl/AuxIndexStructures.h>
21
-
22
23
 
23
24
  namespace faiss {
24
25
 
25
-
26
26
  /*****************************************
27
27
  * IndexIVFFlat implementation
28
28
  ******************************************/
29
29
 
30
- IndexIVFFlat::IndexIVFFlat (Index * quantizer,
31
- size_t d, size_t nlist, MetricType metric):
32
- IndexIVF (quantizer, d, nlist, sizeof(float) * d, metric)
33
- {
30
+ IndexIVFFlat::IndexIVFFlat(
31
+ Index* quantizer,
32
+ size_t d,
33
+ size_t nlist,
34
+ MetricType metric)
35
+ : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
34
36
  code_size = sizeof(float) * d;
35
37
  }
36
38
 
39
+ void IndexIVFFlat::add_core(
40
+ idx_t n,
41
+ const float* x,
42
+ const int64_t* xids,
43
+ const int64_t* coarse_idx)
37
44
 
38
- void IndexIVFFlat::add_with_ids (idx_t n, const float * x, const idx_t *xids)
39
45
  {
40
- add_core (n, x, xids, nullptr);
41
- }
42
-
43
- void IndexIVFFlat::add_core (idx_t n, const float * x, const int64_t *xids,
44
- const int64_t *precomputed_idx)
46
+ FAISS_THROW_IF_NOT(is_trained);
47
+ FAISS_THROW_IF_NOT(coarse_idx);
48
+ assert(invlists);
49
+ direct_map.check_can_add(xids);
45
50
 
46
- {
47
- FAISS_THROW_IF_NOT (is_trained);
48
- assert (invlists);
49
- direct_map.check_can_add (xids);
50
- const int64_t * idx;
51
- ScopeDeleter<int64_t> del;
52
-
53
- if (precomputed_idx) {
54
- idx = precomputed_idx;
55
- } else {
56
- int64_t * idx0 = new int64_t [n];
57
- del.set (idx0);
58
- quantizer->assign (n, x, idx0);
59
- idx = idx0;
60
- }
61
51
  int64_t n_add = 0;
62
- for (size_t i = 0; i < n; i++) {
63
- idx_t id = xids ? xids[i] : ntotal + i;
64
- idx_t list_no = idx [i];
65
- size_t offset;
66
-
67
- if (list_no >= 0) {
68
- const float *xi = x + i * d;
69
- offset = invlists->add_entry (
70
- list_no, id, (const uint8_t*) xi);
71
- n_add++;
72
- } else {
73
- offset = 0;
52
+
53
+ DirectMapAdd dm_adder(direct_map, n, xids);
54
+
55
+ #pragma omp parallel reduction(+ : n_add)
56
+ {
57
+ int nt = omp_get_num_threads();
58
+ int rank = omp_get_thread_num();
59
+
60
+ // each thread takes care of a subset of lists
61
+ for (size_t i = 0; i < n; i++) {
62
+ idx_t list_no = coarse_idx[i];
63
+
64
+ if (list_no >= 0 && list_no % nt == rank) {
65
+ idx_t id = xids ? xids[i] : ntotal + i;
66
+ const float* xi = x + i * d;
67
+ size_t offset =
68
+ invlists->add_entry(list_no, id, (const uint8_t*)xi);
69
+ dm_adder.add(i, list_no, offset);
70
+ n_add++;
71
+ } else if (rank == 0 && list_no == -1) {
72
+ dm_adder.add(i, -1, 0);
73
+ }
74
74
  }
75
- direct_map.add_single_id (id, list_no, offset);
76
75
  }
77
76
 
78
77
  if (verbose) {
79
- printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64 " vectors\n",
80
- n_add, n);
78
+ printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64
79
+ " vectors\n",
80
+ n_add,
81
+ n);
81
82
  }
82
83
  ntotal += n;
83
84
  }
84
85
 
85
- void IndexIVFFlat::encode_vectors(idx_t n, const float* x,
86
- const idx_t * list_nos,
87
- uint8_t * codes,
88
- bool include_listnos) const
89
- {
86
+ void IndexIVFFlat::encode_vectors(
87
+ idx_t n,
88
+ const float* x,
89
+ const idx_t* list_nos,
90
+ uint8_t* codes,
91
+ bool include_listnos) const {
90
92
  if (!include_listnos) {
91
- memcpy (codes, x, code_size * n);
93
+ memcpy(codes, x, code_size * n);
92
94
  } else {
93
- size_t coarse_size = coarse_code_size ();
95
+ size_t coarse_size = coarse_code_size();
94
96
  for (size_t i = 0; i < n; i++) {
95
- int64_t list_no = list_nos [i];
96
- uint8_t *code = codes + i * (code_size + coarse_size);
97
- const float *xi = x + i * d;
97
+ int64_t list_no = list_nos[i];
98
+ uint8_t* code = codes + i * (code_size + coarse_size);
99
+ const float* xi = x + i * d;
98
100
  if (list_no >= 0) {
99
- encode_listno (list_no, code);
100
- memcpy (code + coarse_size, xi, code_size);
101
+ encode_listno(list_no, code);
102
+ memcpy(code + coarse_size, xi, code_size);
101
103
  } else {
102
- memset (code, 0, code_size + coarse_size);
104
+ memset(code, 0, code_size + coarse_size);
103
105
  }
104
-
105
106
  }
106
107
  }
107
108
  }
108
109
 
109
- void IndexIVFFlat::sa_decode (idx_t n, const uint8_t *bytes,
110
- float *x) const
111
- {
112
- size_t coarse_size = coarse_code_size ();
110
+ void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
111
+ size_t coarse_size = coarse_code_size();
113
112
  for (size_t i = 0; i < n; i++) {
114
- const uint8_t *code = bytes + i * (code_size + coarse_size);
115
- float *xi = x + i * d;
116
- memcpy (xi, code + coarse_size, code_size);
113
+ const uint8_t* code = bytes + i * (code_size + coarse_size);
114
+ float* xi = x + i * d;
115
+ memcpy(xi, code + coarse_size, code_size);
117
116
  }
118
117
  }
119
118
 
120
-
121
119
  namespace {
122
120
 
123
-
124
- template<MetricType metric, class C>
125
- struct IVFFlatScanner: InvertedListScanner {
121
+ template <MetricType metric, class C>
122
+ struct IVFFlatScanner : InvertedListScanner {
126
123
  size_t d;
127
124
  bool store_pairs;
128
125
 
129
- IVFFlatScanner(size_t d, bool store_pairs):
130
- d(d), store_pairs(store_pairs) {}
126
+ IVFFlatScanner(size_t d, bool store_pairs)
127
+ : d(d), store_pairs(store_pairs) {}
131
128
 
132
- const float *xi;
133
- void set_query (const float *query) override {
129
+ const float* xi;
130
+ void set_query(const float* query) override {
134
131
  this->xi = query;
135
132
  }
136
133
 
137
134
  idx_t list_no;
138
- void set_list (idx_t list_no, float /* coarse_dis */) override {
135
+ void set_list(idx_t list_no, float /* coarse_dis */) override {
139
136
  this->list_no = list_no;
140
137
  }
141
138
 
142
- float distance_to_code (const uint8_t *code) const override {
143
- const float *yj = (float*)code;
144
- float dis = metric == METRIC_INNER_PRODUCT ?
145
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
139
+ float distance_to_code(const uint8_t* code) const override {
140
+ const float* yj = (float*)code;
141
+ float dis = metric == METRIC_INNER_PRODUCT
142
+ ? fvec_inner_product(xi, yj, d)
143
+ : fvec_L2sqr(xi, yj, d);
146
144
  return dis;
147
145
  }
148
146
 
149
- size_t scan_codes (size_t list_size,
150
- const uint8_t *codes,
151
- const idx_t *ids,
152
- float *simi, idx_t *idxi,
153
- size_t k) const override
154
- {
155
- const float *list_vecs = (const float*)codes;
147
+ size_t scan_codes(
148
+ size_t list_size,
149
+ const uint8_t* codes,
150
+ const idx_t* ids,
151
+ float* simi,
152
+ idx_t* idxi,
153
+ size_t k) const override {
154
+ const float* list_vecs = (const float*)codes;
156
155
  size_t nup = 0;
157
156
  for (size_t j = 0; j < list_size; j++) {
158
- const float * yj = list_vecs + d * j;
159
- float dis = metric == METRIC_INNER_PRODUCT ?
160
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
161
- if (C::cmp (simi[0], dis)) {
162
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
163
- heap_replace_top<C> (k, simi, idxi, dis, id);
157
+ const float* yj = list_vecs + d * j;
158
+ float dis = metric == METRIC_INNER_PRODUCT
159
+ ? fvec_inner_product(xi, yj, d)
160
+ : fvec_L2sqr(xi, yj, d);
161
+ if (C::cmp(simi[0], dis)) {
162
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
163
+ heap_replace_top<C>(k, simi, idxi, dis, id);
164
164
  nup++;
165
165
  }
166
166
  }
167
167
  return nup;
168
168
  }
169
169
 
170
- void scan_codes_range (size_t list_size,
171
- const uint8_t *codes,
172
- const idx_t *ids,
173
- float radius,
174
- RangeQueryResult & res) const override
175
- {
176
- const float *list_vecs = (const float*)codes;
170
+ void scan_codes_range(
171
+ size_t list_size,
172
+ const uint8_t* codes,
173
+ const idx_t* ids,
174
+ float radius,
175
+ RangeQueryResult& res) const override {
176
+ const float* list_vecs = (const float*)codes;
177
177
  for (size_t j = 0; j < list_size; j++) {
178
- const float * yj = list_vecs + d * j;
179
- float dis = metric == METRIC_INNER_PRODUCT ?
180
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
181
- if (C::cmp (radius, dis)) {
182
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
183
- res.add (dis, id);
178
+ const float* yj = list_vecs + d * j;
179
+ float dis = metric == METRIC_INNER_PRODUCT
180
+ ? fvec_inner_product(xi, yj, d)
181
+ : fvec_L2sqr(xi, yj, d);
182
+ if (C::cmp(radius, dis)) {
183
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
184
+ res.add(dis, id);
184
185
  }
185
186
  }
186
187
  }
187
-
188
-
189
188
  };
190
189
 
191
-
192
190
  } // anonymous namespace
193
191
 
194
-
195
-
196
- InvertedListScanner* IndexIVFFlat::get_InvertedListScanner
197
- (bool store_pairs) const
198
- {
192
+ InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
193
+ bool store_pairs) const {
199
194
  if (metric_type == METRIC_INNER_PRODUCT) {
200
- return new IVFFlatScanner<
201
- METRIC_INNER_PRODUCT, CMin<float, int64_t> > (d, store_pairs);
195
+ return new IVFFlatScanner<METRIC_INNER_PRODUCT, CMin<float, int64_t>>(
196
+ d, store_pairs);
202
197
  } else if (metric_type == METRIC_L2) {
203
- return new IVFFlatScanner<
204
- METRIC_L2, CMax<float, int64_t> >(d, store_pairs);
198
+ return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>>(
199
+ d, store_pairs);
205
200
  } else {
206
201
  FAISS_THROW_MSG("metric type not supported");
207
202
  }
208
203
  return nullptr;
209
204
  }
210
205
 
211
-
212
-
213
-
214
- void IndexIVFFlat::reconstruct_from_offset (int64_t list_no, int64_t offset,
215
- float* recons) const
216
- {
217
- memcpy (recons, invlists->get_single_code (list_no, offset), code_size);
206
+ void IndexIVFFlat::reconstruct_from_offset(
207
+ int64_t list_no,
208
+ int64_t offset,
209
+ float* recons) const {
210
+ memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
218
211
  }
219
212
 
220
213
  /*****************************************
221
214
  * IndexIVFFlatDedup implementation
222
215
  ******************************************/
223
216
 
224
- IndexIVFFlatDedup::IndexIVFFlatDedup (
225
- Index * quantizer, size_t d, size_t nlist_,
226
- MetricType metric_type):
227
- IndexIVFFlat (quantizer, d, nlist_, metric_type)
228
- {}
229
-
217
+ IndexIVFFlatDedup::IndexIVFFlatDedup(
218
+ Index* quantizer,
219
+ size_t d,
220
+ size_t nlist_,
221
+ MetricType metric_type)
222
+ : IndexIVFFlat(quantizer, d, nlist_, metric_type) {}
230
223
 
231
- void IndexIVFFlatDedup::train(idx_t n, const float* x)
232
- {
224
+ void IndexIVFFlatDedup::train(idx_t n, const float* x) {
233
225
  std::unordered_map<uint64_t, idx_t> map;
234
- float * x2 = new float [n * d];
235
- ScopeDeleter<float> del (x2);
226
+ float* x2 = new float[n * d];
227
+ ScopeDeleter<float> del(x2);
236
228
 
237
229
  int64_t n2 = 0;
238
230
  for (int64_t i = 0; i < n; i++) {
239
- uint64_t hash = hash_bytes((uint8_t *)(x + i * d), code_size);
231
+ uint64_t hash = hash_bytes((uint8_t*)(x + i * d), code_size);
240
232
  if (map.count(hash) &&
241
- !memcmp (x2 + map[hash] * d, x + i * d, code_size)) {
233
+ !memcmp(x2 + map[hash] * d, x + i * d, code_size)) {
242
234
  // is duplicate, skip
243
235
  } else {
244
- map [hash] = n2;
245
- memcpy (x2 + n2 * d, x + i * d, code_size);
246
- n2 ++;
236
+ map[hash] = n2;
237
+ memcpy(x2 + n2 * d, x + i * d, code_size);
238
+ n2++;
247
239
  }
248
240
  }
249
241
  if (verbose) {
250
- printf ("IndexIVFFlatDedup::train: train on %" PRId64 " points after dedup "
251
- "(was %" PRId64 " points)\n", n2, n);
242
+ printf("IndexIVFFlatDedup::train: train on %" PRId64
243
+ " points after dedup "
244
+ "(was %" PRId64 " points)\n",
245
+ n2,
246
+ n);
252
247
  }
253
- IndexIVFFlat::train (n2, x2);
248
+ IndexIVFFlat::train(n2, x2);
254
249
  }
255
250
 
251
+ void IndexIVFFlatDedup::add_with_ids(
252
+ idx_t na,
253
+ const float* x,
254
+ const idx_t* xids) {
255
+ FAISS_THROW_IF_NOT(is_trained);
256
+ assert(invlists);
257
+ FAISS_THROW_IF_NOT_MSG(
258
+ direct_map.no(), "IVFFlatDedup not implemented with direct_map");
259
+ int64_t* idx = new int64_t[na];
260
+ ScopeDeleter<int64_t> del(idx);
261
+ quantizer->assign(na, x, idx);
256
262
 
263
+ int64_t n_add = 0, n_dup = 0;
257
264
 
258
- void IndexIVFFlatDedup::add_with_ids(
259
- idx_t na, const float* x, const idx_t* xids)
260
- {
265
+ #pragma omp parallel reduction(+ : n_add, n_dup)
266
+ {
267
+ int nt = omp_get_num_threads();
268
+ int rank = omp_get_thread_num();
261
269
 
262
- FAISS_THROW_IF_NOT (is_trained);
263
- assert (invlists);
264
- FAISS_THROW_IF_NOT_MSG (direct_map.no(),
265
- "IVFFlatDedup not implemented with direct_map");
266
- int64_t * idx = new int64_t [na];
267
- ScopeDeleter<int64_t> del (idx);
268
- quantizer->assign (na, x, idx);
270
+ // each thread takes care of a subset of lists
271
+ for (size_t i = 0; i < na; i++) {
272
+ int64_t list_no = idx[i];
269
273
 
270
- int64_t n_add = 0, n_dup = 0;
271
- // TODO make a omp loop with this
272
- for (size_t i = 0; i < na; i++) {
273
- idx_t id = xids ? xids[i] : ntotal + i;
274
- int64_t list_no = idx [i];
274
+ if (list_no < 0 || list_no % nt != rank) {
275
+ continue;
276
+ }
275
277
 
276
- if (list_no < 0) {
277
- continue;
278
- }
279
- const float *xi = x + i * d;
278
+ idx_t id = xids ? xids[i] : ntotal + i;
279
+ const float* xi = x + i * d;
280
280
 
281
- // search if there is already an entry with that id
282
- InvertedLists::ScopedCodes codes (invlists, list_no);
281
+ // search if there is already an entry with that id
282
+ InvertedLists::ScopedCodes codes(invlists, list_no);
283
283
 
284
- int64_t n = invlists->list_size (list_no);
285
- int64_t offset = -1;
286
- for (int64_t o = 0; o < n; o++) {
287
- if (!memcmp (codes.get() + o * code_size,
288
- xi, code_size)) {
289
- offset = o;
290
- break;
284
+ int64_t n = invlists->list_size(list_no);
285
+ int64_t offset = -1;
286
+ for (int64_t o = 0; o < n; o++) {
287
+ if (!memcmp(codes.get() + o * code_size, xi, code_size)) {
288
+ offset = o;
289
+ break;
290
+ }
291
291
  }
292
- }
293
292
 
294
- if (offset == -1) { // not found
295
- invlists->add_entry (list_no, id, (const uint8_t*) xi);
296
- } else {
297
- // mark equivalence
298
- idx_t id2 = invlists->get_single_id (list_no, offset);
299
- std::pair<idx_t, idx_t> pair (id2, id);
300
- instances.insert (pair);
301
- n_dup ++;
293
+ if (offset == -1) { // not found
294
+ invlists->add_entry(list_no, id, (const uint8_t*)xi);
295
+ } else {
296
+ // mark equivalence
297
+ idx_t id2 = invlists->get_single_id(list_no, offset);
298
+ std::pair<idx_t, idx_t> pair(id2, id);
299
+
300
+ #pragma omp critical
301
+ // executed by one thread at a time
302
+ instances.insert(pair);
303
+
304
+ n_dup++;
305
+ }
306
+ n_add++;
302
307
  }
303
- n_add++;
304
308
  }
305
309
  if (verbose) {
306
- printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64 " vectors"
310
+ printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64
311
+ " vectors"
307
312
  " (out of which %" PRId64 " are duplicates)\n",
308
- n_add, na, n_dup);
313
+ n_add,
314
+ na,
315
+ n_dup);
309
316
  }
310
317
  ntotal += n_add;
311
318
  }
312
319
 
313
- void IndexIVFFlatDedup::search_preassigned (
314
- idx_t n, const float *x, idx_t k,
315
- const idx_t *assign,
316
- const float *centroid_dis,
317
- float *distances, idx_t *labels,
318
- bool store_pairs,
319
- const IVFSearchParameters *params,
320
- IndexIVFStats *stats) const
321
- {
322
- FAISS_THROW_IF_NOT_MSG (
323
- !store_pairs, "store_pairs not supported in IVFDedup");
324
-
325
- IndexIVFFlat::search_preassigned (n, x, k, assign, centroid_dis,
326
- distances, labels, false,
327
- params);
328
-
329
- std::vector <idx_t> labels2 (k);
330
- std::vector <float> dis2 (k);
320
+ void IndexIVFFlatDedup::search_preassigned(
321
+ idx_t n,
322
+ const float* x,
323
+ idx_t k,
324
+ const idx_t* assign,
325
+ const float* centroid_dis,
326
+ float* distances,
327
+ idx_t* labels,
328
+ bool store_pairs,
329
+ const IVFSearchParameters* params,
330
+ IndexIVFStats* stats) const {
331
+ FAISS_THROW_IF_NOT_MSG(
332
+ !store_pairs, "store_pairs not supported in IVFDedup");
333
+
334
+ IndexIVFFlat::search_preassigned(
335
+ n, x, k, assign, centroid_dis, distances, labels, false, params);
336
+
337
+ std::vector<idx_t> labels2(k);
338
+ std::vector<float> dis2(k);
331
339
 
332
340
  for (int64_t i = 0; i < n; i++) {
333
- idx_t *labels1 = labels + i * k;
334
- float *dis1 = distances + i * k;
341
+ idx_t* labels1 = labels + i * k;
342
+ float* dis1 = distances + i * k;
335
343
  int64_t j = 0;
336
344
  for (; j < k; j++) {
337
- if (instances.find (labels1[j]) != instances.end ()) {
345
+ if (instances.find(labels1[j]) != instances.end()) {
338
346
  // a duplicate: special handling
339
347
  break;
340
348
  }
@@ -344,11 +352,11 @@ void IndexIVFFlatDedup::search_preassigned (
344
352
  int64_t j0 = j;
345
353
  int64_t rp = j;
346
354
  while (j < k) {
347
- auto range = instances.equal_range (labels1[rp]);
355
+ auto range = instances.equal_range(labels1[rp]);
348
356
  float dis = dis1[rp];
349
357
  labels2[j] = labels1[rp];
350
358
  dis2[j] = dis;
351
- j ++;
359
+ j++;
352
360
  for (auto it = range.first; j < k && it != range.second; ++it) {
353
361
  labels2[j] = it->second;
354
362
  dis2[j] = dis;
@@ -356,21 +364,18 @@ void IndexIVFFlatDedup::search_preassigned (
356
364
  }
357
365
  rp++;
358
366
  }
359
- memcpy (labels1 + j0, labels2.data() + j0,
360
- sizeof(labels1[0]) * (k - j0));
361
- memcpy (dis1 + j0, dis2.data() + j0,
362
- sizeof(dis2[0]) * (k - j0));
367
+ memcpy(labels1 + j0,
368
+ labels2.data() + j0,
369
+ sizeof(labels1[0]) * (k - j0));
370
+ memcpy(dis1 + j0, dis2.data() + j0, sizeof(dis2[0]) * (k - j0));
363
371
  }
364
372
  }
365
-
366
373
  }
367
374
 
368
-
369
- size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
370
- {
375
+ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel) {
371
376
  std::unordered_map<idx_t, idx_t> replace;
372
- std::vector<std::pair<idx_t, idx_t> > toadd;
373
- for (auto it = instances.begin(); it != instances.end(); ) {
377
+ std::vector<std::pair<idx_t, idx_t>> toadd;
378
+ for (auto it = instances.begin(); it != instances.end();) {
374
379
  if (sel.is_member(it->first)) {
375
380
  // then we erase this entry
376
381
  if (!sel.is_member(it->second)) {
@@ -378,8 +383,8 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
378
383
  if (replace.count(it->first) == 0) {
379
384
  replace[it->first] = it->second;
380
385
  } else { // remember we should add an element
381
- std::pair<idx_t, idx_t> new_entry (
382
- replace[it->first], it->second);
386
+ std::pair<idx_t, idx_t> new_entry(
387
+ replace[it->first], it->second);
383
388
  toadd.push_back(new_entry);
384
389
  }
385
390
  }
@@ -393,32 +398,34 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
393
398
  }
394
399
  }
395
400
 
396
- instances.insert (toadd.begin(), toadd.end());
401
+ instances.insert(toadd.begin(), toadd.end());
397
402
 
398
403
  // mostly copied from IndexIVF.cpp
399
404
 
400
- FAISS_THROW_IF_NOT_MSG (direct_map.no(),
401
- "direct map remove not implemented");
405
+ FAISS_THROW_IF_NOT_MSG(
406
+ direct_map.no(), "direct map remove not implemented");
402
407
 
403
408
  std::vector<int64_t> toremove(nlist);
404
409
 
405
410
  #pragma omp parallel for
406
411
  for (int64_t i = 0; i < nlist; i++) {
407
- int64_t l0 = invlists->list_size (i), l = l0, j = 0;
408
- InvertedLists::ScopedIds idsi (invlists, i);
412
+ int64_t l0 = invlists->list_size(i), l = l0, j = 0;
413
+ InvertedLists::ScopedIds idsi(invlists, i);
409
414
  while (j < l) {
410
- if (sel.is_member (idsi[j])) {
415
+ if (sel.is_member(idsi[j])) {
411
416
  if (replace.count(idsi[j]) == 0) {
412
417
  l--;
413
- invlists->update_entry (
414
- i, j,
415
- invlists->get_single_id (i, l),
416
- InvertedLists::ScopedCodes (invlists, i, l).get());
418
+ invlists->update_entry(
419
+ i,
420
+ j,
421
+ invlists->get_single_id(i, l),
422
+ InvertedLists::ScopedCodes(invlists, i, l).get());
417
423
  } else {
418
- invlists->update_entry (
419
- i, j,
420
- replace[idsi[j]],
421
- InvertedLists::ScopedCodes (invlists, i, j).get());
424
+ invlists->update_entry(
425
+ i,
426
+ j,
427
+ replace[idsi[j]],
428
+ InvertedLists::ScopedCodes(invlists, i, j).get());
422
429
  j++;
423
430
  }
424
431
  } else {
@@ -432,37 +439,28 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
432
439
  for (int64_t i = 0; i < nlist; i++) {
433
440
  if (toremove[i] > 0) {
434
441
  nremove += toremove[i];
435
- invlists->resize(
436
- i, invlists->list_size(i) - toremove[i]);
442
+ invlists->resize(i, invlists->list_size(i) - toremove[i]);
437
443
  }
438
444
  }
439
445
  ntotal -= nremove;
440
446
  return nremove;
441
447
  }
442
448
 
443
-
444
449
  void IndexIVFFlatDedup::range_search(
445
- idx_t ,
446
- const float* ,
447
- float ,
448
- RangeSearchResult* ) const
449
- {
450
- FAISS_THROW_MSG ("not implemented");
450
+ idx_t,
451
+ const float*,
452
+ float,
453
+ RangeSearchResult*) const {
454
+ FAISS_THROW_MSG("not implemented");
451
455
  }
452
456
 
453
- void IndexIVFFlatDedup::update_vectors (int , const idx_t *, const float *)
454
- {
455
- FAISS_THROW_MSG ("not implemented");
457
+ void IndexIVFFlatDedup::update_vectors(int, const idx_t*, const float*) {
458
+ FAISS_THROW_MSG("not implemented");
456
459
  }
457
460
 
458
-
459
- void IndexIVFFlatDedup::reconstruct_from_offset (
460
- int64_t , int64_t , float* ) const
461
- {
462
- FAISS_THROW_MSG ("not implemented");
461
+ void IndexIVFFlatDedup::reconstruct_from_offset(int64_t, int64_t, float*)
462
+ const {
463
+ FAISS_THROW_MSG("not implemented");
463
464
  }
464
465
 
465
-
466
-
467
-
468
466
  } // namespace faiss