faiss 0.4.1 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +39 -29
  5. data/vendor/faiss/faiss/Clustering.cpp +4 -2
  6. data/vendor/faiss/faiss/IVFlib.cpp +14 -7
  7. data/vendor/faiss/faiss/Index.h +72 -3
  8. data/vendor/faiss/faiss/Index2Layer.cpp +2 -4
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +0 -1
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +1 -0
  11. data/vendor/faiss/faiss/IndexBinary.h +46 -3
  12. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +118 -4
  13. data/vendor/faiss/faiss/IndexBinaryHNSW.h +41 -0
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +18 -7
  16. data/vendor/faiss/faiss/IndexBinaryIVF.h +5 -1
  17. data/vendor/faiss/faiss/IndexFlat.cpp +6 -4
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +65 -24
  19. data/vendor/faiss/faiss/IndexHNSW.h +10 -1
  20. data/vendor/faiss/faiss/IndexIDMap.cpp +96 -18
  21. data/vendor/faiss/faiss/IndexIDMap.h +20 -0
  22. data/vendor/faiss/faiss/IndexIVF.cpp +28 -10
  23. data/vendor/faiss/faiss/IndexIVF.h +16 -1
  24. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -16
  25. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +18 -6
  26. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +33 -21
  27. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +16 -6
  28. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +24 -15
  29. data/vendor/faiss/faiss/IndexIVFFastScan.h +4 -2
  30. data/vendor/faiss/faiss/IndexIVFFlat.cpp +59 -43
  31. data/vendor/faiss/faiss/IndexIVFFlat.h +10 -2
  32. data/vendor/faiss/faiss/IndexIVFPQ.cpp +16 -3
  33. data/vendor/faiss/faiss/IndexIVFPQ.h +8 -1
  34. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +14 -6
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -1
  36. data/vendor/faiss/faiss/IndexIVFPQR.cpp +14 -4
  37. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  38. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +28 -3
  39. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +8 -1
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +9 -2
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  42. data/vendor/faiss/faiss/IndexLattice.cpp +8 -4
  43. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -7
  44. data/vendor/faiss/faiss/IndexNSG.cpp +3 -3
  45. data/vendor/faiss/faiss/IndexPQ.cpp +0 -1
  46. data/vendor/faiss/faiss/IndexPQ.h +1 -0
  47. data/vendor/faiss/faiss/IndexPQFastScan.cpp +0 -2
  48. data/vendor/faiss/faiss/IndexPreTransform.cpp +4 -2
  49. data/vendor/faiss/faiss/IndexRefine.cpp +11 -6
  50. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +16 -4
  51. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -3
  52. data/vendor/faiss/faiss/IndexShards.cpp +7 -6
  53. data/vendor/faiss/faiss/MatrixStats.cpp +16 -8
  54. data/vendor/faiss/faiss/MetaIndexes.cpp +12 -6
  55. data/vendor/faiss/faiss/MetricType.h +5 -3
  56. data/vendor/faiss/faiss/clone_index.cpp +2 -4
  57. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +6 -0
  58. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +9 -4
  59. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +32 -10
  60. data/vendor/faiss/faiss/gpu/GpuIndex.h +88 -0
  61. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +125 -0
  62. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +39 -4
  63. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +3 -3
  64. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -1
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +3 -2
  66. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +41 -0
  67. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +6 -3
  68. data/vendor/faiss/faiss/impl/HNSW.cpp +34 -19
  69. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -1
  70. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +2 -3
  71. data/vendor/faiss/faiss/impl/NNDescent.cpp +17 -9
  72. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +42 -21
  73. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +6 -24
  74. data/vendor/faiss/faiss/impl/ResultHandler.h +56 -47
  75. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +28 -15
  76. data/vendor/faiss/faiss/impl/index_read.cpp +36 -11
  77. data/vendor/faiss/faiss/impl/index_write.cpp +19 -6
  78. data/vendor/faiss/faiss/impl/io.cpp +9 -5
  79. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +18 -11
  80. data/vendor/faiss/faiss/impl/mapped_io.cpp +4 -7
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +0 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +0 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +6 -6
  84. data/vendor/faiss/faiss/impl/zerocopy_io.cpp +1 -1
  85. data/vendor/faiss/faiss/impl/zerocopy_io.h +2 -2
  86. data/vendor/faiss/faiss/index_factory.cpp +49 -33
  87. data/vendor/faiss/faiss/index_factory.h +8 -2
  88. data/vendor/faiss/faiss/index_io.h +0 -3
  89. data/vendor/faiss/faiss/invlists/DirectMap.cpp +2 -1
  90. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +12 -6
  91. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +8 -4
  92. data/vendor/faiss/faiss/utils/Heap.cpp +15 -8
  93. data/vendor/faiss/faiss/utils/Heap.h +23 -12
  94. data/vendor/faiss/faiss/utils/distances.cpp +42 -21
  95. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  96. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +1 -1
  97. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -3
  98. data/vendor/faiss/faiss/utils/extra_distances-inl.h +27 -4
  99. data/vendor/faiss/faiss/utils/extra_distances.cpp +8 -4
  100. data/vendor/faiss/faiss/utils/hamming.cpp +20 -10
  101. data/vendor/faiss/faiss/utils/partitioning.cpp +8 -4
  102. data/vendor/faiss/faiss/utils/quantize_lut.cpp +17 -9
  103. data/vendor/faiss/faiss/utils/rabitq_simd.h +539 -0
  104. data/vendor/faiss/faiss/utils/random.cpp +14 -7
  105. data/vendor/faiss/faiss/utils/utils.cpp +0 -3
  106. metadata +5 -2
@@ -26,6 +26,8 @@
26
26
  #include <faiss/utils/hamming.h>
27
27
  #include <faiss/utils/random.h>
28
28
 
29
+ #include <random>
30
+
29
31
  namespace faiss {
30
32
 
31
33
  /**************************************************************
@@ -98,7 +100,9 @@ void hnsw_add_vertices(
98
100
 
99
101
  int i1 = n;
100
102
 
101
- for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
103
+ for (int pt_level = hist.size() - 1;
104
+ pt_level >= int(!index_hnsw.init_level0);
105
+ pt_level--) {
102
106
  int i0 = i1 - hist[pt_level];
103
107
 
104
108
  if (verbose) {
@@ -125,7 +129,13 @@ void hnsw_add_vertices(
125
129
  dis->set_query(
126
130
  (float*)(x + (pt_id - n0) * index_hnsw.code_size));
127
131
 
128
- hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
132
+ hnsw.add_with_locks(
133
+ *dis,
134
+ pt_level,
135
+ pt_id,
136
+ locks,
137
+ vt,
138
+ index_hnsw.keep_max_size_level0 && (pt_level == 0));
129
139
 
130
140
  if (prev_display >= 0 && i - i0 > prev_display + 10000) {
131
141
  prev_display = i - i0;
@@ -136,14 +146,19 @@ void hnsw_add_vertices(
136
146
  }
137
147
  i1 = i0;
138
148
  }
139
- FAISS_ASSERT(i1 == 0);
149
+ if (index_hnsw.init_level0) {
150
+ FAISS_ASSERT(i1 == 0);
151
+ } else {
152
+ FAISS_ASSERT((i1 - hist[0]) == 0);
153
+ }
140
154
  }
141
155
  if (verbose) {
142
156
  printf("Done in %.3f ms\n", getmillisecs() - t0);
143
157
  }
144
158
 
145
- for (int i = 0; i < ntotal; i++)
159
+ for (int i = 0; i < ntotal; i++) {
146
160
  omp_destroy_lock(&locks[i]);
161
+ }
147
162
  }
148
163
 
149
164
  } // anonymous namespace
@@ -296,4 +311,103 @@ DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
296
311
  return dispatch_HammingComputer(code_size, bd, flat_storage);
297
312
  }
298
313
 
314
+ /**************************************************************
315
+ * IndexBinaryHNSWCagra implementation
316
+ **************************************************************/
317
+
318
+ IndexBinaryHNSWCagra::IndexBinaryHNSWCagra() : IndexBinaryHNSW() {
319
+ storage = nullptr;
320
+ }
321
+
322
+ IndexBinaryHNSWCagra::IndexBinaryHNSWCagra(int d, int M)
323
+ : IndexBinaryHNSW(d, M) {
324
+ init_level0 = true;
325
+ keep_max_size_level0 = true;
326
+ }
327
+
328
+ void IndexBinaryHNSWCagra::add(idx_t n, const uint8_t* x) {
329
+ FAISS_THROW_IF_NOT_MSG(
330
+ !base_level_only,
331
+ "Cannot add vectors when base_level_only is set to True");
332
+
333
+ IndexBinaryHNSW::add(n, x);
334
+ }
335
+
336
+ void IndexBinaryHNSWCagra::search(
337
+ idx_t n,
338
+ const uint8_t* x,
339
+ idx_t k,
340
+ int32_t* distances,
341
+ idx_t* labels,
342
+ const SearchParameters* params) const {
343
+ if (!base_level_only) {
344
+ IndexBinaryHNSW::search(n, x, k, distances, labels, params);
345
+ } else {
346
+ float* distances_f = (float*)distances;
347
+
348
+ using RH = HeapBlockResultHandler<HNSW::C>;
349
+ RH bres(n, distances_f, labels, k);
350
+
351
+ std::vector<storage_idx_t> nearest(n);
352
+ std::vector<float> nearest_d(n);
353
+
354
+ #pragma omp parallel for
355
+ for (idx_t i = 0; i < n; i++) {
356
+ std::unique_ptr<DistanceComputer> dis(get_distance_computer());
357
+ dis->set_query((float*)(x + i * code_size));
358
+
359
+ nearest[i] = -1;
360
+ nearest_d[i] = std::numeric_limits<float>::max();
361
+
362
+ std::random_device rd;
363
+ std::mt19937 gen(rd());
364
+ std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);
365
+
366
+ for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
367
+ auto idx = distrib(gen);
368
+ float distance = (*dis)(idx);
369
+
370
+ if (distance < nearest_d[i]) {
371
+ nearest[i] = idx;
372
+ nearest_d[i] = distance;
373
+ }
374
+ }
375
+ FAISS_THROW_IF_NOT_MSG(
376
+ nearest[i] >= 0, "Could not find a valid entrypoint.");
377
+ }
378
+
379
+ #pragma omp parallel
380
+ {
381
+ VisitedTable vt(ntotal);
382
+ std::unique_ptr<DistanceComputer> dis(get_distance_computer());
383
+ HNSWStats search_stats;
384
+ RH::SingleResultHandler res(bres);
385
+
386
+ #pragma omp for
387
+ for (idx_t i = 0; i < n; i++) {
388
+ res.begin(i);
389
+ dis->set_query((float*)(x + i * code_size));
390
+
391
+ hnsw.search_level_0(
392
+ *dis,
393
+ res,
394
+ 1,
395
+ &nearest[i],
396
+ &nearest_d[i],
397
+ 1, // search_type
398
+ search_stats,
399
+ vt,
400
+ params);
401
+
402
+ res.end();
403
+ }
404
+ }
405
+
406
+ #pragma omp parallel for
407
+ for (int i = 0; i < n * k; ++i) {
408
+ distances[i] = std::round(distances_f[i]);
409
+ }
410
+ }
411
+ }
412
+
299
413
  } // namespace faiss
@@ -28,6 +28,18 @@ struct IndexBinaryHNSW : IndexBinary {
28
28
  bool own_fields;
29
29
  IndexBinary* storage;
30
30
 
31
+ // When set to false, level 0 in the knn graph is not initialized.
32
+ // This option is used by GpuIndexBinaryCagra::copyTo(IndexBinaryHNSW*)
33
+ // as level 0 knn graph is copied over from the index built by
34
+ // GpuIndexBinaryCagra.
35
+ bool init_level0 = true;
36
+
37
+ // When set to true, all neighbors in level 0 are filled up
38
+ // to the maximum size allowed (2 * M). This option is used by
39
+ // IndexBinaryHNSWCagra to create a full base layer graph that is
40
+ // used when GpuIndexBinaryCagra::copyFrom(IndexBinaryHNSW*) is called.
41
+ bool keep_max_size_level0 = false;
42
+
31
43
  explicit IndexBinaryHNSW();
32
44
  explicit IndexBinaryHNSW(int d, int M = 32);
33
45
  explicit IndexBinaryHNSW(IndexBinary* storage, int M = 32);
@@ -55,4 +67,33 @@ struct IndexBinaryHNSW : IndexBinary {
55
67
  void reset() override;
56
68
  };
57
69
 
70
+ struct IndexBinaryHNSWCagra : IndexBinaryHNSW {
71
+ IndexBinaryHNSWCagra();
72
+ IndexBinaryHNSWCagra(int d, int M);
73
+
74
+ /// When set to true, the index is immutable.
75
+ /// This option is used to copy the knn graph from GpuIndexBinaryCagra
76
+ /// to the base level of IndexBinaryHNSWCagra without adding upper levels.
77
+ /// Doing so enables to search the HNSW index, but removes the
78
+ /// ability to add vectors.
79
+ bool base_level_only = false;
80
+
81
+ /// When `base_level_only` is set to `True`, the search function
82
+ /// searches only the base level knn graph of the HNSW index.
83
+ /// This parameter selects the entry point by randomly selecting
84
+ /// some points and using the best one.
85
+ int num_base_level_search_entrypoints = 32;
86
+
87
+ void add(idx_t n, const uint8_t* x) override;
88
+
89
+ /// entry point for search
90
+ void search(
91
+ idx_t n,
92
+ const uint8_t* x,
93
+ idx_t k,
94
+ int32_t* distances,
95
+ idx_t* labels,
96
+ const SearchParameters* params = nullptr) const override;
97
+ };
98
+
58
99
  } // namespace faiss
@@ -11,7 +11,6 @@
11
11
 
12
12
  #include <cinttypes>
13
13
  #include <cstdio>
14
- #include <memory>
15
14
  #include <unordered_set>
16
15
 
17
16
  #include <faiss/utils/hamming.h>
@@ -26,9 +26,16 @@
26
26
 
27
27
  namespace faiss {
28
28
 
29
- IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
29
+ IndexBinaryIVF::IndexBinaryIVF(
30
+ IndexBinary* quantizer,
31
+ size_t d,
32
+ size_t nlist,
33
+ bool own_invlists)
30
34
  : IndexBinary(d),
31
- invlists(new ArrayInvertedLists(nlist, code_size)),
35
+ invlists(
36
+ own_invlists ? new ArrayInvertedLists(nlist, code_size)
37
+ : nullptr),
38
+ own_invlists(own_invlists),
32
39
  quantizer(quantizer),
33
40
  nlist(nlist) {
34
41
  FAISS_THROW_IF_NOT(d == quantizer->d);
@@ -283,7 +290,7 @@ void IndexBinaryIVF::check_compatible_for_merge(
283
290
  direct_map.no() && other->direct_map.no(),
284
291
  "direct map copy not implemented");
285
292
  FAISS_THROW_IF_NOT_MSG(
286
- typeid(*this) == typeid(other),
293
+ typeid(*this) == typeid(*other),
287
294
  "can only merge indexes of the same type");
288
295
  }
289
296
 
@@ -444,8 +451,9 @@ void search_knn_hamming_heap(
444
451
  list_size, scodes.get(), ids, simi, idxi, k);
445
452
 
446
453
  nscan += list_size;
447
- if (max_codes && nscan >= max_codes)
454
+ if (max_codes && nscan >= max_codes) {
448
455
  break;
456
+ }
449
457
  }
450
458
 
451
459
  ndis += nscan;
@@ -532,8 +540,9 @@ void search_knn_hamming_count(
532
540
  }
533
541
 
534
542
  nscan += list_size;
535
- if (max_codes && nscan >= max_codes)
543
+ if (max_codes && nscan >= max_codes) {
536
544
  break;
545
+ }
537
546
  }
538
547
  ndis += nscan;
539
548
 
@@ -850,8 +859,9 @@ void IndexBinaryIVF::range_search_preassigned(
850
859
 
851
860
  auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
852
861
  idx_t key = assign[i * nprobe_2 + ik]; /* select the list */
853
- if (key < 0)
862
+ if (key < 0) {
854
863
  return;
864
+ }
855
865
  FAISS_THROW_IF_NOT_FMT(
856
866
  key < (idx_t)nlist,
857
867
  "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
@@ -860,8 +870,9 @@ void IndexBinaryIVF::range_search_preassigned(
860
870
  nlist);
861
871
  const size_t list_size = invlists->list_size(key);
862
872
 
863
- if (list_size == 0)
873
+ if (list_size == 0) {
864
874
  return;
875
+ }
865
876
 
866
877
  InvertedLists::ScopedCodes scodes(invlists, key);
867
878
  InvertedLists::ScopedIds ids(invlists, key);
@@ -68,7 +68,11 @@ struct IndexBinaryIVF : IndexBinary {
68
68
  * identifier. The pointer is borrowed: the quantizer should not
69
69
  * be deleted while the IndexBinaryIVF is in use.
70
70
  */
71
- IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist);
71
+ IndexBinaryIVF(
72
+ IndexBinary* quantizer,
73
+ size_t d,
74
+ size_t nlist,
75
+ bool own_invlists = true);
72
76
 
73
77
  IndexBinaryIVF();
74
78
 
@@ -239,6 +239,7 @@ FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
239
239
  }
240
240
 
241
241
  void IndexFlat::reconstruct(idx_t key, float* recons) const {
242
+ FAISS_THROW_IF_NOT(key < ntotal);
242
243
  memcpy(recons, &(codes[key * code_size]), code_size);
243
244
  }
244
245
 
@@ -399,8 +400,9 @@ void IndexFlat1D::update_permutation() {
399
400
 
400
401
  void IndexFlat1D::add(idx_t n, const float* x) {
401
402
  IndexFlatL2::add(n, x);
402
- if (continuous_update)
403
+ if (continuous_update) {
403
404
  update_permutation();
405
+ }
404
406
  }
405
407
 
406
408
  void IndexFlat1D::reset() {
@@ -452,10 +454,11 @@ void IndexFlat1D::search(
452
454
 
453
455
  while (i0 + 1 < i1) {
454
456
  idx_t imed = (i0 + i1) / 2;
455
- if (xb[perm[imed]] <= q)
457
+ if (xb[perm[imed]] <= q) {
456
458
  i0 = imed;
457
- else
459
+ } else {
458
460
  i1 = imed;
461
+ }
459
462
  }
460
463
 
461
464
  // query is between xb[perm[i0]] and xb[perm[i1]]
@@ -516,5 +519,4 @@ void IndexFlat1D::search(
516
519
  done:;
517
520
  }
518
521
  }
519
-
520
522
  } // namespace faiss
@@ -19,6 +19,7 @@
19
19
  #include <random>
20
20
 
21
21
  #include <cstdint>
22
+ #include "faiss/Index.h"
22
23
 
23
24
  #include <faiss/Index2Layer.h>
24
25
  #include <faiss/IndexFlat.h>
@@ -81,8 +82,9 @@ void hnsw_add_vertices(
81
82
  }
82
83
 
83
84
  std::vector<omp_lock_t> locks(ntotal);
84
- for (int i = 0; i < ntotal; i++)
85
+ for (int i = 0; i < ntotal; i++) {
85
86
  omp_init_lock(&locks[i]);
87
+ }
86
88
 
87
89
  // add vectors from highest to lowest level
88
90
  std::vector<int> hist;
@@ -94,8 +96,9 @@ void hnsw_add_vertices(
94
96
  for (int i = 0; i < n; i++) {
95
97
  storage_idx_t pt_id = i + n0;
96
98
  int pt_level = hnsw.levels[pt_id] - 1;
97
- while (pt_level >= hist.size())
99
+ while (pt_level >= hist.size()) {
98
100
  hist.push_back(0);
101
+ }
99
102
  hist[pt_level]++;
100
103
  }
101
104
 
@@ -131,8 +134,9 @@ void hnsw_add_vertices(
131
134
  }
132
135
 
133
136
  // random permutation to get rid of dataset order bias
134
- for (int j = i0; j < i1; j++)
137
+ for (int j = i0; j < i1; j++) {
135
138
  std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
139
+ }
136
140
 
137
141
  bool interrupt = false;
138
142
 
@@ -377,8 +381,9 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size) {
377
381
 
378
382
  for (size_t j = begin; j < end; j++) {
379
383
  int v1 = hnsw.neighbors[j];
380
- if (v1 < 0)
384
+ if (v1 < 0) {
381
385
  break;
386
+ }
382
387
  initial_list.emplace(dis->symmetric_dis(i, v1), v1);
383
388
 
384
389
  // initial_list.emplace(qdis(v1), v1);
@@ -389,10 +394,11 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size) {
389
394
  *dis, initial_list, shrunk_list, new_size);
390
395
 
391
396
  for (size_t j = begin; j < end; j++) {
392
- if (j - begin < shrunk_list.size())
397
+ if (j - begin < shrunk_list.size()) {
393
398
  hnsw.neighbors[j] = shrunk_list[j - begin].id;
394
- else
399
+ } else {
395
400
  hnsw.neighbors[j] = -1;
401
+ }
396
402
  }
397
403
  }
398
404
  }
@@ -472,10 +478,12 @@ void IndexHNSW::init_level_0_from_knngraph(
472
478
 
473
479
  for (size_t j = 0; j < k; j++) {
474
480
  int v1 = I[i * k + j];
475
- if (v1 == i)
481
+ if (v1 == i) {
476
482
  continue;
477
- if (v1 < 0)
483
+ }
484
+ if (v1 < 0) {
478
485
  break;
486
+ }
479
487
  initial_list.emplace(D[i * k + j], v1);
480
488
  }
481
489
 
@@ -486,10 +494,11 @@ void IndexHNSW::init_level_0_from_knngraph(
486
494
  hnsw.neighbor_range(i, 0, &begin, &end);
487
495
 
488
496
  for (size_t j = begin; j < end; j++) {
489
- if (j - begin < shrunk_list.size())
497
+ if (j - begin < shrunk_list.size()) {
490
498
  hnsw.neighbors[j] = shrunk_list[j - begin].id;
491
- else
499
+ } else {
492
500
  hnsw.neighbors[j] = -1;
501
+ }
493
502
  }
494
503
  }
495
504
  }
@@ -499,8 +508,9 @@ void IndexHNSW::init_level_0_from_entry_points(
499
508
  const storage_idx_t* points,
500
509
  const storage_idx_t* nearests) {
501
510
  std::vector<omp_lock_t> locks(ntotal);
502
- for (int i = 0; i < ntotal; i++)
511
+ for (int i = 0; i < ntotal; i++) {
503
512
  omp_init_lock(&locks[i]);
513
+ }
504
514
 
505
515
  #pragma omp parallel
506
516
  {
@@ -530,8 +540,9 @@ void IndexHNSW::init_level_0_from_entry_points(
530
540
  printf("\n");
531
541
  }
532
542
 
533
- for (int i = 0; i < ntotal; i++)
543
+ for (int i = 0; i < ntotal; i++) {
534
544
  omp_destroy_lock(&locks[i]);
545
+ }
535
546
  }
536
547
 
537
548
  void IndexHNSW::reorder_links() {
@@ -578,8 +589,9 @@ void IndexHNSW::link_singletons() {
578
589
  hnsw.neighbor_range(i, 0, &begin, &end);
579
590
  for (size_t j = begin; j < end; j++) {
580
591
  storage_idx_t ni = hnsw.neighbors[j];
581
- if (ni >= 0)
592
+ if (ni >= 0) {
582
593
  seen[ni] = true;
594
+ }
583
595
  }
584
596
  }
585
597
 
@@ -589,8 +601,9 @@ void IndexHNSW::link_singletons() {
589
601
  if (!seen[i]) {
590
602
  singletons.push_back(i);
591
603
  n_sing++;
592
- if (hnsw.levels[i] > 1)
604
+ if (hnsw.levels[i] > 1) {
593
605
  n_sing_l1++;
606
+ }
594
607
  }
595
608
  }
596
609
 
@@ -722,8 +735,9 @@ int search_from_candidates_2(
722
735
 
723
736
  for (size_t j = begin; j < end; j++) {
724
737
  int v1 = hnsw.neighbors[j];
725
- if (v1 < 0)
738
+ if (v1 < 0) {
726
739
  break;
740
+ }
727
741
  if (vt.visited[v1] == vt.visno + 1) {
728
742
  // nothing to do
729
743
  } else {
@@ -749,8 +763,9 @@ int search_from_candidates_2(
749
763
  }
750
764
 
751
765
  stats.n1++;
752
- if (candidates.size() == 0)
766
+ if (candidates.size() == 0) {
753
767
  stats.n2++;
768
+ }
754
769
 
755
770
  return nres;
756
771
  }
@@ -814,8 +829,9 @@ void IndexHNSW2Level::search(
814
829
 
815
830
  for (int j = 0; j < nprobe; j++) {
816
831
  idx_t key = coarse_assign[j + i * nprobe];
817
- if (key < 0)
832
+ if (key < 0) {
818
833
  break;
834
+ }
819
835
  size_t list_length = index_ivfpq->get_list_size(key);
820
836
  const idx_t* ids = index_ivfpq->invlists->get_ids(key);
821
837
 
@@ -827,8 +843,9 @@ void IndexHNSW2Level::search(
827
843
  candidates.clear();
828
844
 
829
845
  for (int j = 0; j < k; j++) {
830
- if (idxi[j] < 0)
846
+ if (idxi[j] < 0) {
831
847
  break;
848
+ }
832
849
  candidates.push(idxi[j], simi[j]);
833
850
  }
834
851
 
@@ -893,15 +910,31 @@ IndexHNSWCagra::IndexHNSWCagra() {
893
910
  is_trained = true;
894
911
  }
895
912
 
896
- IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
897
- : IndexHNSW(
898
- (metric == METRIC_L2)
899
- ? static_cast<IndexFlat*>(new IndexFlatL2(d))
900
- : static_cast<IndexFlat*>(new IndexFlatIP(d)),
901
- M) {
913
+ IndexHNSWCagra::IndexHNSWCagra(
914
+ int d,
915
+ int M,
916
+ MetricType metric,
917
+ NumericType numeric_type)
918
+ : IndexHNSW(d, M, metric) {
902
919
  FAISS_THROW_IF_NOT_MSG(
903
920
  ((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
904
921
  "unsupported metric type for IndexHNSWCagra");
922
+ numeric_type_ = numeric_type;
923
+ if (numeric_type == NumericType::Float32) {
924
+ // Use flat storage with full precision for fp32
925
+ storage = (metric == METRIC_L2)
926
+ ? static_cast<Index*>(new IndexFlatL2(d))
927
+ : static_cast<Index*>(new IndexFlatIP(d));
928
+ } else if (numeric_type == NumericType::Float16) {
929
+ auto qtype = ScalarQuantizer::QT_fp16;
930
+ storage = new IndexScalarQuantizer(d, qtype, metric);
931
+ } else {
932
+ FAISS_THROW_MSG(
933
+ "Unsupported numeric_type: only F16 and F32 are supported for IndexHNSWCagra");
934
+ }
935
+
936
+ metric_arg = storage->metric_arg;
937
+
905
938
  own_fields = true;
906
939
  is_trained = true;
907
940
  init_level0 = true;
@@ -967,4 +1000,12 @@ void IndexHNSWCagra::search(
967
1000
  }
968
1001
  }
969
1002
 
1003
+ faiss::NumericType IndexHNSWCagra::get_numeric_type() const {
1004
+ return numeric_type_;
1005
+ }
1006
+
1007
+ void IndexHNSWCagra::set_numeric_type(faiss::NumericType numeric_type) {
1008
+ numeric_type_ = numeric_type;
1009
+ }
1010
+
970
1011
  } // namespace faiss
@@ -10,6 +10,7 @@
10
10
  #pragma once
11
11
 
12
12
  #include <vector>
13
+ #include "faiss/Index.h"
13
14
 
14
15
  #include <faiss/IndexFlat.h>
15
16
  #include <faiss/IndexPQ.h>
@@ -170,7 +171,11 @@ struct IndexHNSW2Level : IndexHNSW {
170
171
 
171
172
  struct IndexHNSWCagra : IndexHNSW {
172
173
  IndexHNSWCagra();
173
- IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2);
174
+ IndexHNSWCagra(
175
+ int d,
176
+ int M,
177
+ MetricType metric = METRIC_L2,
178
+ NumericType numeric_type = NumericType::Float32);
174
179
 
175
180
  /// When set to true, the index is immutable.
176
181
  /// This option is used to copy the knn graph from GpuIndexCagra
@@ -195,6 +200,10 @@ struct IndexHNSWCagra : IndexHNSW {
195
200
  float* distances,
196
201
  idx_t* labels,
197
202
  const SearchParameters* params = nullptr) const override;
203
+
204
+ faiss::NumericType get_numeric_type() const;
205
+ void set_numeric_type(faiss::NumericType numeric_type);
206
+ NumericType numeric_type_;
198
207
  };
199
208
 
200
209
  } // namespace faiss