faiss 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +18 -18
  4. data/README.md +1 -1
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/Clustering.cpp +318 -53
  7. data/vendor/faiss/Clustering.h +39 -11
  8. data/vendor/faiss/DirectMap.cpp +267 -0
  9. data/vendor/faiss/DirectMap.h +120 -0
  10. data/vendor/faiss/IVFlib.cpp +24 -4
  11. data/vendor/faiss/IVFlib.h +4 -0
  12. data/vendor/faiss/Index.h +5 -24
  13. data/vendor/faiss/Index2Layer.cpp +0 -1
  14. data/vendor/faiss/IndexBinary.h +7 -3
  15. data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
  16. data/vendor/faiss/IndexBinaryFlat.h +3 -0
  17. data/vendor/faiss/IndexBinaryHash.cpp +492 -0
  18. data/vendor/faiss/IndexBinaryHash.h +116 -0
  19. data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
  20. data/vendor/faiss/IndexBinaryIVF.h +14 -4
  21. data/vendor/faiss/IndexFlat.h +2 -1
  22. data/vendor/faiss/IndexHNSW.cpp +68 -16
  23. data/vendor/faiss/IndexHNSW.h +3 -3
  24. data/vendor/faiss/IndexIVF.cpp +72 -76
  25. data/vendor/faiss/IndexIVF.h +24 -5
  26. data/vendor/faiss/IndexIVFFlat.cpp +19 -54
  27. data/vendor/faiss/IndexIVFFlat.h +1 -11
  28. data/vendor/faiss/IndexIVFPQ.cpp +49 -26
  29. data/vendor/faiss/IndexIVFPQ.h +9 -10
  30. data/vendor/faiss/IndexIVFPQR.cpp +2 -2
  31. data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
  32. data/vendor/faiss/IndexLSH.h +4 -1
  33. data/vendor/faiss/IndexPreTransform.cpp +0 -1
  34. data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
  35. data/vendor/faiss/InvertedLists.cpp +0 -2
  36. data/vendor/faiss/MetaIndexes.cpp +0 -1
  37. data/vendor/faiss/MetricType.h +36 -0
  38. data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
  39. data/vendor/faiss/c_api/Clustering_c.h +11 -5
  40. data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
  41. data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
  42. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
  43. data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
  44. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
  45. data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
  46. data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
  47. data/vendor/faiss/gpu/GpuDistance.h +93 -0
  48. data/vendor/faiss/gpu/GpuIndex.h +7 -0
  49. data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
  50. data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
  51. data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
  52. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
  53. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
  54. data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
  55. data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
  56. data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
  57. data/vendor/faiss/impl/HNSW.cpp +0 -1
  58. data/vendor/faiss/impl/PolysemousTraining.h +5 -5
  59. data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
  60. data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
  61. data/vendor/faiss/impl/ProductQuantizer.h +42 -47
  62. data/vendor/faiss/impl/index_read.cpp +103 -7
  63. data/vendor/faiss/impl/index_write.cpp +101 -5
  64. data/vendor/faiss/impl/io.cpp +111 -1
  65. data/vendor/faiss/impl/io.h +38 -0
  66. data/vendor/faiss/index_factory.cpp +0 -1
  67. data/vendor/faiss/tests/test_merge.cpp +0 -1
  68. data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
  69. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
  70. data/vendor/faiss/utils/distances.cpp +4 -5
  71. data/vendor/faiss/utils/distances_simd.cpp +0 -1
  72. data/vendor/faiss/utils/hamming.cpp +85 -3
  73. data/vendor/faiss/utils/hamming.h +20 -0
  74. data/vendor/faiss/utils/utils.cpp +0 -96
  75. data/vendor/faiss/utils/utils.h +0 -15
  76. metadata +11 -3
  77. data/lib/faiss/ext.bundle +0 -0
@@ -46,8 +46,7 @@ struct IndexBinaryIVF : IndexBinary {
46
46
  bool use_heap = true;
47
47
 
48
48
  /// map for direct access to the elements. Enables reconstruct().
49
- bool maintain_direct_map;
50
- std::vector<idx_t> direct_map;
49
+ DirectMap direct_map;
51
50
 
52
51
  IndexBinary *quantizer; ///< quantizer that maps vectors to inverted lists
53
52
  size_t nlist; ///< number of possible key values
@@ -110,8 +109,11 @@ struct IndexBinaryIVF : IndexBinary {
110
109
  bool store_pairs=false) const;
111
110
 
112
111
  /** assign the vectors, then call search_preassign */
113
- virtual void search(idx_t n, const uint8_t *x, idx_t k,
114
- int32_t *distances, idx_t *labels) const override;
112
+ void search(idx_t n, const uint8_t *x, idx_t k,
113
+ int32_t *distances, idx_t *labels) const override;
114
+
115
+ void range_search(idx_t n, const uint8_t *x, int radius,
116
+ RangeSearchResult *result) const override;
115
117
 
116
118
  void reconstruct(idx_t key, uint8_t *recons) const override;
117
119
 
@@ -168,6 +170,8 @@ struct IndexBinaryIVF : IndexBinary {
168
170
  */
169
171
  void make_direct_map(bool new_maintain_direct_map=true);
170
172
 
173
+ void set_direct_map_type (DirectMap::Type type);
174
+
171
175
  void replace_invlists(InvertedLists *il, bool own=false);
172
176
  };
173
177
 
@@ -201,6 +205,12 @@ struct BinaryInvertedListScanner {
201
205
  int32_t *distances, idx_t *labels,
202
206
  size_t k) const = 0;
203
207
 
208
+ virtual void scan_codes_range (size_t n,
209
+ const uint8_t *codes,
210
+ const idx_t *ids,
211
+ int radius,
212
+ RangeQueryResult &result) const = 0;
213
+
204
214
  virtual ~BinaryInvertedListScanner () {}
205
215
 
206
216
  };
@@ -19,6 +19,7 @@ namespace faiss {
19
19
 
20
20
  /** Index that stores the full vectors and performs exhaustive search */
21
21
  struct IndexFlat: Index {
22
+
22
23
  /// database vectors, size ntotal * d
23
24
  std::vector<float> xb;
24
25
 
@@ -144,7 +145,7 @@ struct IndexRefineFlat: Index {
144
145
  };
145
146
 
146
147
 
147
- /// optimized version for 1D "vectors"
148
+ /// optimized version for 1D "vectors".
148
149
  struct IndexFlat1D:IndexFlatL2 {
149
150
  bool continuous_update; ///< is the permutation updated continuously?
150
151
 
@@ -26,7 +26,6 @@
26
26
  #include <stdint.h>
27
27
 
28
28
  #ifdef __SSE__
29
- #include <immintrin.h>
30
29
  #endif
31
30
 
32
31
  #include <faiss/utils/distances.h>
@@ -55,7 +54,6 @@ namespace faiss {
55
54
  using idx_t = Index::idx_t;
56
55
  using MinimaxHeap = HNSW::MinimaxHeap;
57
56
  using storage_idx_t = HNSW::storage_idx_t;
58
- using NodeDistCloser = HNSW::NodeDistCloser;
59
57
  using NodeDistFarther = HNSW::NodeDistFarther;
60
58
 
61
59
  HNSWStats hnsw_stats;
@@ -67,6 +65,50 @@ HNSWStats hnsw_stats;
67
65
  namespace {
68
66
 
69
67
 
68
+ /* Wrap the distance computer into one that negates the
69
+ distances. This makes supporting INNER_PRODUCT search easier */
70
+
71
+ struct NegativeDistanceComputer: DistanceComputer {
72
+
73
+ /// owned by this
74
+ DistanceComputer *basedis;
75
+
76
+ explicit NegativeDistanceComputer(DistanceComputer *basedis):
77
+ basedis(basedis)
78
+ {}
79
+
80
+ void set_query(const float *x) override {
81
+ basedis->set_query(x);
82
+ }
83
+
84
+ /// compute distance of vector i to current query
85
+ float operator () (idx_t i) override {
86
+ return -(*basedis)(i);
87
+ }
88
+
89
+ /// compute distance between two stored vectors
90
+ float symmetric_dis (idx_t i, idx_t j) override {
91
+ return -basedis->symmetric_dis(i, j);
92
+ }
93
+
94
+ virtual ~NegativeDistanceComputer ()
95
+ {
96
+ delete basedis;
97
+ }
98
+
99
+ };
100
+
101
+ DistanceComputer *storage_distance_computer(const Index *storage)
102
+ {
103
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
104
+ return new NegativeDistanceComputer(storage->get_distance_computer());
105
+ } else {
106
+ return storage->get_distance_computer();
107
+ }
108
+ }
109
+
110
+
111
+
70
112
  void hnsw_add_vertices(IndexHNSW &index_hnsw,
71
113
  size_t n0,
72
114
  size_t n, const float *x,
@@ -152,7 +194,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
152
194
  VisitedTable vt (ntotal);
153
195
 
154
196
  DistanceComputer *dis =
155
- index_hnsw.storage->get_distance_computer();
197
+ storage_distance_computer (index_hnsw.storage);
156
198
  ScopeDeleter1<DistanceComputer> del(dis);
157
199
  int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
158
200
  size_t counter = 0;
@@ -210,8 +252,8 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
210
252
  * IndexHNSW implementation
211
253
  **************************************************************/
212
254
 
213
- IndexHNSW::IndexHNSW(int d, int M):
214
- Index(d, METRIC_L2),
255
+ IndexHNSW::IndexHNSW(int d, int M, MetricType metric):
256
+ Index(d, metric),
215
257
  hnsw(M),
216
258
  own_fields(false),
217
259
  storage(nullptr),
@@ -258,7 +300,8 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
258
300
  #pragma omp parallel reduction(+ : nreorder)
259
301
  {
260
302
  VisitedTable vt (ntotal);
261
- DistanceComputer *dis = storage->get_distance_computer();
303
+
304
+ DistanceComputer *dis = storage_distance_computer(storage);
262
305
  ScopeDeleter1<DistanceComputer> del(dis);
263
306
 
264
307
  #pragma omp for
@@ -290,6 +333,14 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
290
333
  }
291
334
  InterruptCallback::check ();
292
335
  }
336
+
337
+ if (metric_type == METRIC_INNER_PRODUCT) {
338
+ // we need to revert the negated distances
339
+ for (size_t i = 0; i < k * n; i++) {
340
+ distances[i] = -distances[i];
341
+ }
342
+ }
343
+
293
344
  hnsw_stats.nreorder += nreorder;
294
345
  }
295
346
 
@@ -323,7 +374,7 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size)
323
374
  {
324
375
  #pragma omp parallel
325
376
  {
326
- DistanceComputer *dis = storage->get_distance_computer();
377
+ DistanceComputer *dis = storage_distance_computer(storage);
327
378
  ScopeDeleter1<DistanceComputer> del(dis);
328
379
 
329
380
  #pragma omp for
@@ -367,7 +418,7 @@ void IndexHNSW::search_level_0(
367
418
  storage_idx_t ntotal = hnsw.levels.size();
368
419
  #pragma omp parallel
369
420
  {
370
- DistanceComputer *qdis = storage->get_distance_computer();
421
+ DistanceComputer *qdis = storage_distance_computer(storage);
371
422
  ScopeDeleter1<DistanceComputer> del(qdis);
372
423
 
373
424
  VisitedTable vt (ntotal);
@@ -436,7 +487,7 @@ void IndexHNSW::init_level_0_from_knngraph(
436
487
 
437
488
  #pragma omp parallel for
438
489
  for (idx_t i = 0; i < ntotal; i++) {
439
- DistanceComputer *qdis = storage->get_distance_computer();
490
+ DistanceComputer *qdis = storage_distance_computer(storage);
440
491
  float vec[d];
441
492
  storage->reconstruct(i, vec);
442
493
  qdis->set_query(vec);
@@ -480,7 +531,7 @@ void IndexHNSW::init_level_0_from_entry_points(
480
531
  {
481
532
  VisitedTable vt (ntotal);
482
533
 
483
- DistanceComputer *dis = storage->get_distance_computer();
534
+ DistanceComputer *dis = storage_distance_computer(storage);
484
535
  ScopeDeleter1<DistanceComputer> del(dis);
485
536
  float vec[storage->d];
486
537
 
@@ -518,7 +569,7 @@ void IndexHNSW::reorder_links()
518
569
  std::vector<float> distances (M);
519
570
  std::vector<size_t> order (M);
520
571
  std::vector<storage_idx_t> tmp (M);
521
- DistanceComputer *dis = storage->get_distance_computer();
572
+ DistanceComputer *dis = storage_distance_computer(storage);
522
573
  ScopeDeleter1<DistanceComputer> del(dis);
523
574
 
524
575
  #pragma omp for
@@ -826,8 +877,8 @@ IndexHNSWFlat::IndexHNSWFlat()
826
877
  is_trained = true;
827
878
  }
828
879
 
829
- IndexHNSWFlat::IndexHNSWFlat(int d, int M):
830
- IndexHNSW(new IndexFlatL2(d), M)
880
+ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric):
881
+ IndexHNSW(new IndexFlat(d, metric), M)
831
882
  {
832
883
  own_fields = true;
833
884
  is_trained = true;
@@ -860,8 +911,9 @@ void IndexHNSWPQ::train(idx_t n, const float* x)
860
911
  **************************************************************/
861
912
 
862
913
 
863
- IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M):
864
- IndexHNSW (new IndexScalarQuantizer (d, qtype), M)
914
+ IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M,
915
+ MetricType metric):
916
+ IndexHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
865
917
  {
866
918
  is_trained = false;
867
919
  own_fields = true;
@@ -986,7 +1038,7 @@ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
986
1038
  #pragma omp parallel
987
1039
  {
988
1040
  VisitedTable vt (ntotal);
989
- DistanceComputer *dis = storage->get_distance_computer();
1041
+ DistanceComputer *dis = storage_distance_computer(storage);
990
1042
  ScopeDeleter1<DistanceComputer> del(dis);
991
1043
 
992
1044
  int candidates_size = hnsw.upper_beam;
@@ -79,7 +79,7 @@ struct IndexHNSW : Index {
79
79
 
80
80
  ReconstructFromNeighbors *reconstruct_from_neighbors;
81
81
 
82
- explicit IndexHNSW (int d = 0, int M = 32);
82
+ explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
83
83
  explicit IndexHNSW (Index *storage, int M = 32);
84
84
 
85
85
  ~IndexHNSW() override;
@@ -132,7 +132,7 @@ struct IndexHNSW : Index {
132
132
 
133
133
  struct IndexHNSWFlat : IndexHNSW {
134
134
  IndexHNSWFlat();
135
- IndexHNSWFlat(int d, int M);
135
+ IndexHNSWFlat(int d, int M, MetricType metric = METRIC_L2);
136
136
  };
137
137
 
138
138
  /** PQ index topped with a HNSW structure to access elements
@@ -149,7 +149,7 @@ struct IndexHNSWPQ : IndexHNSW {
149
149
  */
150
150
  struct IndexHNSWSQ : IndexHNSW {
151
151
  IndexHNSWSQ();
152
- IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
152
+ IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M, MetricType metric = METRIC_L2);
153
153
  };
154
154
 
155
155
  /** 2-level code structure with fast random access
@@ -157,8 +157,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
157
157
  code_size (code_size),
158
158
  nprobe (1),
159
159
  max_codes (0),
160
- parallel_mode (0),
161
- maintain_direct_map (false)
160
+ parallel_mode (0)
162
161
  {
163
162
  FAISS_THROW_IF_NOT (d == quantizer->d);
164
163
  is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
@@ -172,8 +171,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
172
171
  IndexIVF::IndexIVF ():
173
172
  invlists (nullptr), own_invlists (false),
174
173
  code_size (0),
175
- nprobe (1), max_codes (0), parallel_mode (0),
176
- maintain_direct_map (false)
174
+ nprobe (1), max_codes (0), parallel_mode (0)
177
175
  {}
178
176
 
179
177
  void IndexIVF::add (idx_t n, const float * x)
@@ -199,6 +197,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
199
197
  }
200
198
 
201
199
  FAISS_THROW_IF_NOT (is_trained);
200
+ direct_map.check_can_add (xids);
201
+
202
202
  std::unique_ptr<idx_t []> idx(new idx_t[n]);
203
203
  quantizer->assign (n, x, idx.get());
204
204
  size_t nadd = 0, nminus1 = 0;
@@ -210,6 +210,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
210
210
  std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
211
211
  encode_vectors (n, x, idx.get(), flat_codes.get());
212
212
 
213
+ DirectMapAdd dm_adder(direct_map, n, xids);
214
+
213
215
  #pragma omp parallel reduction(+: nadd)
214
216
  {
215
217
  int nt = omp_get_num_threads();
@@ -220,13 +222,21 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
220
222
  idx_t list_no = idx [i];
221
223
  if (list_no >= 0 && list_no % nt == rank) {
222
224
  idx_t id = xids ? xids[i] : ntotal + i;
223
- invlists->add_entry (list_no, id,
224
- flat_codes.get() + i * code_size);
225
+ size_t ofs = invlists->add_entry (
226
+ list_no, id,
227
+ flat_codes.get() + i * code_size
228
+ );
229
+
230
+ dm_adder.add (i, list_no, ofs);
231
+
225
232
  nadd++;
233
+ } else if (rank == 0 && list_no == -1) {
234
+ dm_adder.add (i, -1, 0);
226
235
  }
227
236
  }
228
237
  }
229
238
 
239
+
230
240
  if (verbose) {
231
241
  printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1);
232
242
  }
@@ -234,30 +244,18 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
234
244
  ntotal += n;
235
245
  }
236
246
 
237
-
238
- void IndexIVF::make_direct_map (bool new_maintain_direct_map)
247
+ void IndexIVF::make_direct_map (bool b)
239
248
  {
240
- // nothing to do
241
- if (new_maintain_direct_map == maintain_direct_map)
242
- return;
243
-
244
- if (new_maintain_direct_map) {
245
- direct_map.resize (ntotal, -1);
246
- for (size_t key = 0; key < nlist; key++) {
247
- size_t list_size = invlists->list_size (key);
248
- ScopedIds idlist (invlists, key);
249
-
250
- for (long ofs = 0; ofs < list_size; ofs++) {
251
- FAISS_THROW_IF_NOT_MSG (
252
- 0 <= idlist [ofs] && idlist[ofs] < ntotal,
253
- "direct map supported only for seuquential ids");
254
- direct_map [idlist [ofs]] = key << 32 | ofs;
255
- }
256
- }
249
+ if (b) {
250
+ direct_map.set_type (DirectMap::Array, invlists, ntotal);
257
251
  } else {
258
- direct_map.clear ();
252
+ direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
259
253
  }
260
- maintain_direct_map = new_maintain_direct_map;
254
+ }
255
+
256
+ void IndexIVF::set_direct_map_type (DirectMap::Type type)
257
+ {
258
+ direct_map.set_type (type, invlists, ntotal);
261
259
  }
262
260
 
263
261
 
@@ -298,10 +296,13 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
298
296
 
299
297
  bool interrupt = false;
300
298
 
299
+ int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
300
+ bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
301
+
301
302
  // don't start parallel section if single query
302
303
  bool do_parallel =
303
- parallel_mode == 0 ? n > 1 :
304
- parallel_mode == 1 ? nprobe > 1 :
304
+ pmode == 0 ? n > 1 :
305
+ pmode == 1 ? nprobe > 1 :
305
306
  nprobe * n > 1;
306
307
 
307
308
  #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
@@ -318,6 +319,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
318
319
  // initialize + reorder a result heap
319
320
 
320
321
  auto init_result = [&](float *simi, idx_t *idxi) {
322
+ if (!do_heap_init) return;
321
323
  if (metric_type == METRIC_INNER_PRODUCT) {
322
324
  heap_heapify<HeapForIP> (k, simi, idxi);
323
325
  } else {
@@ -326,6 +328,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
326
328
  };
327
329
 
328
330
  auto reorder_result = [&] (float *simi, idx_t *idxi) {
331
+ if (!do_heap_init) return;
329
332
  if (metric_type == METRIC_INNER_PRODUCT) {
330
333
  heap_reorder<HeapForIP> (k, simi, idxi);
331
334
  } else {
@@ -377,7 +380,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
377
380
  * Actual loops, depending on parallel_mode
378
381
  ****************************************************/
379
382
 
380
- if (parallel_mode == 0) {
383
+ if (pmode == 0) {
381
384
 
382
385
  #pragma omp for
383
386
  for (size_t i = 0; i < n; i++) {
@@ -417,7 +420,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
417
420
  }
418
421
 
419
422
  } // parallel for
420
- } else if (parallel_mode == 1) {
423
+ } else if (pmode == 1) {
421
424
  std::vector <idx_t> local_idx (k);
422
425
  std::vector <float> local_dis (k);
423
426
 
@@ -460,7 +463,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
460
463
  }
461
464
  } else {
462
465
  FAISS_THROW_FMT ("parallel_mode %d not supported\n",
463
- parallel_mode);
466
+ pmode);
464
467
  }
465
468
  } // parallel section
466
469
 
@@ -608,13 +611,8 @@ InvertedListScanner *IndexIVF::get_InvertedListScanner (
608
611
 
609
612
  void IndexIVF::reconstruct (idx_t key, float* recons) const
610
613
  {
611
- FAISS_THROW_IF_NOT_MSG (direct_map.size() == ntotal,
612
- "direct map is not initialized");
613
- FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
614
- "invalid key");
615
- idx_t list_no = direct_map[key] >> 32;
616
- idx_t offset = direct_map[key] & 0xffffffff;
617
- reconstruct_from_offset (list_no, offset, recons);
614
+ idx_t lo = direct_map.get (key);
615
+ reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
618
616
  }
619
617
 
620
618
 
@@ -682,8 +680,8 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
682
680
  // Fill with NaNs
683
681
  memset(reconstructed, -1, sizeof(*reconstructed) * d);
684
682
  } else {
685
- int list_no = key >> 32;
686
- int offset = key & 0xffffffff;
683
+ int list_no = lo_listno (key);
684
+ int offset = lo_offset (key);
687
685
 
688
686
  // Update label to the actual id
689
687
  labels[ij] = invlists->get_single_id (list_no, offset);
@@ -711,42 +709,41 @@ void IndexIVF::reset ()
711
709
 
712
710
  size_t IndexIVF::remove_ids (const IDSelector & sel)
713
711
  {
714
- FAISS_THROW_IF_NOT_MSG (!maintain_direct_map,
715
- "direct map remove not implemented");
716
-
717
- std::vector<idx_t> toremove(nlist);
718
-
719
- #pragma omp parallel for
720
- for (idx_t i = 0; i < nlist; i++) {
721
- idx_t l0 = invlists->list_size (i), l = l0, j = 0;
722
- ScopedIds idsi (invlists, i);
723
- while (j < l) {
724
- if (sel.is_member (idsi[j])) {
725
- l--;
726
- invlists->update_entry (
727
- i, j,
728
- invlists->get_single_id (i, l),
729
- ScopedCodes (invlists, i, l).get());
730
- } else {
731
- j++;
732
- }
733
- }
734
- toremove[i] = l0 - l;
735
- }
736
- // this will not run well in parallel on ondisk because of possible shrinks
737
- size_t nremove = 0;
738
- for (idx_t i = 0; i < nlist; i++) {
739
- if (toremove[i] > 0) {
740
- nremove += toremove[i];
741
- invlists->resize(
742
- i, invlists->list_size(i) - toremove[i]);
743
- }
744
- }
712
+ size_t nremove = direct_map.remove_ids (sel, invlists);
745
713
  ntotal -= nremove;
746
714
  return nremove;
747
715
  }
748
716
 
749
717
 
718
+ void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
719
+ {
720
+
721
+ if (direct_map.type == DirectMap::Hashtable) {
722
+ // just remove then add
723
+ IDSelectorArray sel(n, new_ids);
724
+ size_t nremove = remove_ids (sel);
725
+ FAISS_THROW_IF_NOT_MSG (nremove == n,
726
+ "did not find all entries to remove");
727
+ add_with_ids (n, x, new_ids);
728
+ return;
729
+ }
730
+
731
+ FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
732
+ // here it is more tricky because we don't want to introduce holes
733
+ // in continuous range of ids
734
+
735
+ FAISS_THROW_IF_NOT (is_trained);
736
+ std::vector<idx_t> assign (n);
737
+ quantizer->assign (n, x, assign.data());
738
+
739
+ std::vector<uint8_t> flat_codes (n * code_size);
740
+ encode_vectors (n, x, assign.data(), flat_codes.data());
741
+
742
+ direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
743
+
744
+ }
745
+
746
+
750
747
 
751
748
 
752
749
  void IndexIVF::train (idx_t n, const float *x)
@@ -779,15 +776,14 @@ void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
779
776
  FAISS_THROW_IF_NOT (other.code_size == code_size);
780
777
  FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
781
778
  "can only merge indexes of the same type");
779
+ FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
780
+ "merge direct_map not implemented");
782
781
  }
783
782
 
784
783
 
785
784
  void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
786
785
  {
787
786
  check_compatible_for_merge (other);
788
- FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
789
- !other.maintain_direct_map),
790
- "direct map copy not implemented");
791
787
 
792
788
  invlists->merge_from (other.invlists, add_id);
793
789
 
@@ -817,7 +813,7 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
817
813
 
818
814
  FAISS_THROW_IF_NOT (nlist == other.nlist);
819
815
  FAISS_THROW_IF_NOT (code_size == other.code_size);
820
- FAISS_THROW_IF_NOT (!other.maintain_direct_map);
816
+ FAISS_THROW_IF_NOT (other.direct_map.no());
821
817
  FAISS_THROW_IF_NOT_FMT (
822
818
  subset_type == 0 || subset_type == 1 || subset_type == 2,
823
819
  "subset type %d not implemented", subset_type);