faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -5,17 +5,14 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <faiss/IndexPQ.h>
12
11
  #include <faiss/impl/ProductQuantizer.h>
13
12
  #include <faiss/utils/AlignedTable.h>
14
13
 
15
-
16
14
  namespace faiss {
17
15
 
18
-
19
16
  /** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
20
17
  *
21
18
  * The codes are not stored sequentially but grouped in blocks of size bbs.
@@ -28,7 +25,7 @@ namespace faiss {
28
25
  * 15: no qbs with reservoir accumulator
29
26
  */
30
27
 
31
- struct IndexPQFastScan: Index {
28
+ struct IndexPQFastScan : Index {
32
29
  ProductQuantizer pq;
33
30
 
34
31
  // implementation to select
@@ -37,8 +34,8 @@ struct IndexPQFastScan: Index {
37
34
  int skip = 0;
38
35
 
39
36
  // size of the kernel
40
- int bbs; // set at build time
41
- int qbs = 0; // query block size 0 = use default
37
+ int bbs; // set at build time
38
+ int qbs = 0; // query block size 0 = use default
42
39
 
43
40
  // packed version of the codes
44
41
  size_t ntotal2;
@@ -47,22 +44,23 @@ struct IndexPQFastScan: Index {
47
44
  AlignedTable<uint8_t> codes;
48
45
 
49
46
  // this is for testing purposes only (set when initialized by IndexPQ)
50
- const uint8_t *orig_codes = nullptr;
47
+ const uint8_t* orig_codes = nullptr;
51
48
 
52
49
  IndexPQFastScan(
53
- int d, size_t M, size_t nbits,
54
- MetricType metric = METRIC_L2,
55
- int bbs = 32
56
- );
50
+ int d,
51
+ size_t M,
52
+ size_t nbits,
53
+ MetricType metric = METRIC_L2,
54
+ int bbs = 32);
57
55
 
58
56
  IndexPQFastScan();
59
57
 
60
58
  /// build from an existing IndexPQ
61
- explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
59
+ explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
62
60
 
63
- void train (idx_t n, const float *x) override;
64
- void add (idx_t n, const float *x) override;
65
- void reset() override ;
61
+ void train(idx_t n, const float* x) override;
62
+ void add(idx_t n, const float* x) override;
63
+ void reset() override;
66
64
  void search(
67
65
  idx_t n,
68
66
  const float* x,
@@ -72,35 +70,51 @@ struct IndexPQFastScan: Index {
72
70
 
73
71
  // called by search function
74
72
  void compute_quantized_LUT(
75
- idx_t n, const float* x,
76
- uint8_t *lut, float *normalizers) const ;
73
+ idx_t n,
74
+ const float* x,
75
+ uint8_t* lut,
76
+ float* normalizers) const;
77
77
 
78
- template<bool is_max>
78
+ template <bool is_max>
79
79
  void search_dispatch_implem(
80
- idx_t n, const float* x, idx_t k,
81
- float* distances, idx_t* labels) const;
80
+ idx_t n,
81
+ const float* x,
82
+ idx_t k,
83
+ float* distances,
84
+ idx_t* labels) const;
82
85
 
83
- template<class C>
86
+ template <class C>
84
87
  void search_implem_2(
85
- idx_t n, const float* x, idx_t k,
86
- float* distances, idx_t* labels) const;
87
-
88
+ idx_t n,
89
+ const float* x,
90
+ idx_t k,
91
+ float* distances,
92
+ idx_t* labels) const;
88
93
 
89
- template<class C>
94
+ template <class C>
90
95
  void search_implem_12(
91
- idx_t n, const float* x, idx_t k,
92
- float* distances, idx_t* labels, int impl) const;
96
+ idx_t n,
97
+ const float* x,
98
+ idx_t k,
99
+ float* distances,
100
+ idx_t* labels,
101
+ int impl) const;
93
102
 
94
- template<class C>
103
+ template <class C>
95
104
  void search_implem_14(
96
- idx_t n, const float* x, idx_t k,
97
- float* distances, idx_t* labels, int impl) const;
98
-
105
+ idx_t n,
106
+ const float* x,
107
+ idx_t k,
108
+ float* distances,
109
+ idx_t* labels,
110
+ int impl) const;
99
111
  };
100
112
 
101
113
  struct FastScanStats {
102
114
  uint64_t t0, t1, t2, t3;
103
- FastScanStats() {reset();}
115
+ FastScanStats() {
116
+ reset();
117
+ }
104
118
  void reset() {
105
119
  memset(this, 0, sizeof(*this));
106
120
  }
@@ -9,13 +9,13 @@
9
9
 
10
10
  #include <faiss/IndexPreTransform.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <cmath>
13
+ #include <cstdio>
14
14
  #include <cstring>
15
15
  #include <memory>
16
16
 
17
- #include <faiss/impl/FaissAssert.h>
18
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
19
 
20
20
  namespace faiss {
21
21
 
@@ -23,44 +23,29 @@ namespace faiss {
23
23
  * IndexPreTransform
24
24
  *********************************************/
25
25
 
26
- IndexPreTransform::IndexPreTransform ():
27
- index(nullptr), own_fields (false)
28
- {
29
- }
30
-
26
+ IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
31
27
 
32
- IndexPreTransform::IndexPreTransform (
33
- Index * index):
34
- Index (index->d, index->metric_type),
35
- index (index), own_fields (false)
36
- {
28
+ IndexPreTransform::IndexPreTransform(Index* index)
29
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
37
30
  is_trained = index->is_trained;
38
31
  ntotal = index->ntotal;
39
32
  }
40
33
 
41
-
42
- IndexPreTransform::IndexPreTransform (
43
- VectorTransform * ltrans,
44
- Index * index):
45
- Index (index->d, index->metric_type),
46
- index (index), own_fields (false)
47
- {
34
+ IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
35
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
48
36
  is_trained = index->is_trained;
49
37
  ntotal = index->ntotal;
50
- prepend_transform (ltrans);
38
+ prepend_transform(ltrans);
51
39
  }
52
40
 
53
- void IndexPreTransform::prepend_transform (VectorTransform *ltrans)
54
- {
55
- FAISS_THROW_IF_NOT (ltrans->d_out == d);
41
+ void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
42
+ FAISS_THROW_IF_NOT(ltrans->d_out == d);
56
43
  is_trained = is_trained && ltrans->is_trained;
57
- chain.insert (chain.begin(), ltrans);
44
+ chain.insert(chain.begin(), ltrans);
58
45
  d = ltrans->d_in;
59
46
  }
60
47
 
61
-
62
- IndexPreTransform::~IndexPreTransform ()
63
- {
48
+ IndexPreTransform::~IndexPreTransform() {
64
49
  if (own_fields) {
65
50
  for (int i = 0; i < chain.size(); i++)
66
51
  delete chain[i];
@@ -68,11 +53,7 @@ IndexPreTransform::~IndexPreTransform ()
68
53
  }
69
54
  }
70
55
 
71
-
72
-
73
-
74
- void IndexPreTransform::train (idx_t n, const float *x)
75
- {
56
+ void IndexPreTransform::train(idx_t n, const float* x) {
76
57
  int last_untrained = 0;
77
58
  if (!index->is_trained) {
78
59
  last_untrained = chain.size();
@@ -84,7 +65,7 @@ void IndexPreTransform::train (idx_t n, const float *x)
84
65
  }
85
66
  }
86
67
  }
87
- const float *prev_x = x;
68
+ const float* prev_x = x;
88
69
  ScopeDeleter<float> del;
89
70
 
90
71
  if (verbose) {
@@ -93,34 +74,35 @@ void IndexPreTransform::train (idx_t n, const float *x)
93
74
  }
94
75
 
95
76
  for (int i = 0; i <= last_untrained; i++) {
96
-
97
77
  if (i < chain.size()) {
98
- VectorTransform *ltrans = chain [i];
78
+ VectorTransform* ltrans = chain[i];
99
79
  if (!ltrans->is_trained) {
100
80
  if (verbose) {
101
81
  printf(" Training chain component %d/%zd\n",
102
- i, chain.size());
103
- if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
82
+ i,
83
+ chain.size());
84
+ if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
104
85
  opqm->verbose = true;
105
86
  }
106
87
  }
107
- ltrans->train (n, prev_x);
88
+ ltrans->train(n, prev_x);
108
89
  }
109
90
  } else {
110
91
  if (verbose) {
111
92
  printf(" Training sub-index\n");
112
93
  }
113
- index->train (n, prev_x);
94
+ index->train(n, prev_x);
114
95
  }
115
- if (i == last_untrained) break;
96
+ if (i == last_untrained)
97
+ break;
116
98
  if (verbose) {
117
- printf(" Applying transform %d/%zd\n",
118
- i, chain.size());
99
+ printf(" Applying transform %d/%zd\n", i, chain.size());
119
100
  }
120
101
 
121
- float * xt = chain[i]->apply (n, prev_x);
102
+ float* xt = chain[i]->apply(n, prev_x);
122
103
 
123
- if (prev_x != x) delete [] prev_x;
104
+ if (prev_x != x)
105
+ delete[] prev_x;
124
106
  prev_x = xt;
125
107
  del.set(xt);
126
108
  }
@@ -128,200 +110,190 @@ void IndexPreTransform::train (idx_t n, const float *x)
128
110
  is_trained = true;
129
111
  }
130
112
 
131
-
132
- const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
133
- {
134
- const float *prev_x = x;
113
+ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
114
+ const float* prev_x = x;
135
115
  ScopeDeleter<float> del;
136
116
 
137
117
  for (int i = 0; i < chain.size(); i++) {
138
- float * xt = chain[i]->apply (n, prev_x);
139
- ScopeDeleter<float> del2 (xt);
140
- del2.swap (del);
118
+ float* xt = chain[i]->apply(n, prev_x);
119
+ ScopeDeleter<float> del2(xt);
120
+ del2.swap(del);
141
121
  prev_x = xt;
142
122
  }
143
- del.release ();
123
+ del.release();
144
124
  return prev_x;
145
125
  }
146
126
 
147
- void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
148
- {
127
+ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
128
+ const {
149
129
  const float* next_x = xt;
150
130
  ScopeDeleter<float> del;
151
131
 
152
132
  for (int i = chain.size() - 1; i >= 0; i--) {
153
- float* prev_x = (i == 0) ? x : new float [n * chain[i]->d_in];
154
- ScopeDeleter<float> del2 ((prev_x == x) ? nullptr : prev_x);
155
- chain [i]->reverse_transform (n, next_x, prev_x);
156
- del2.swap (del);
133
+ float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
134
+ ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
135
+ chain[i]->reverse_transform(n, next_x, prev_x);
136
+ del2.swap(del);
157
137
  next_x = prev_x;
158
138
  }
159
139
  }
160
140
 
161
- void IndexPreTransform::add (idx_t n, const float *x)
162
- {
163
- FAISS_THROW_IF_NOT (is_trained);
164
- const float *xt = apply_chain (n, x);
141
+ void IndexPreTransform::add(idx_t n, const float* x) {
142
+ FAISS_THROW_IF_NOT(is_trained);
143
+ const float* xt = apply_chain(n, x);
165
144
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
166
- index->add (n, xt);
145
+ index->add(n, xt);
167
146
  ntotal = index->ntotal;
168
147
  }
169
148
 
170
- void IndexPreTransform::add_with_ids (idx_t n, const float * x,
171
- const idx_t *xids)
172
- {
173
- FAISS_THROW_IF_NOT (is_trained);
174
- const float *xt = apply_chain (n, x);
149
+ void IndexPreTransform::add_with_ids(
150
+ idx_t n,
151
+ const float* x,
152
+ const idx_t* xids) {
153
+ FAISS_THROW_IF_NOT(is_trained);
154
+ const float* xt = apply_chain(n, x);
175
155
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
176
- index->add_with_ids (n, xt, xids);
156
+ index->add_with_ids(n, xt, xids);
177
157
  ntotal = index->ntotal;
178
158
  }
179
159
 
160
+ void IndexPreTransform::search(
161
+ idx_t n,
162
+ const float* x,
163
+ idx_t k,
164
+ float* distances,
165
+ idx_t* labels) const {
166
+ FAISS_THROW_IF_NOT(k > 0);
180
167
 
181
-
182
-
183
- void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
184
- float *distances, idx_t *labels) const
185
- {
186
- FAISS_THROW_IF_NOT (is_trained);
187
- const float *xt = apply_chain (n, x);
168
+ FAISS_THROW_IF_NOT(is_trained);
169
+ const float* xt = apply_chain(n, x);
188
170
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
189
- index->search (n, xt, k, distances, labels);
171
+ index->search(n, xt, k, distances, labels);
190
172
  }
191
173
 
192
- void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
193
- RangeSearchResult* result) const
194
- {
195
- FAISS_THROW_IF_NOT (is_trained);
196
- const float *xt = apply_chain (n, x);
174
+ void IndexPreTransform::range_search(
175
+ idx_t n,
176
+ const float* x,
177
+ float radius,
178
+ RangeSearchResult* result) const {
179
+ FAISS_THROW_IF_NOT(is_trained);
180
+ const float* xt = apply_chain(n, x);
197
181
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
198
- index->range_search (n, xt, radius, result);
182
+ index->range_search(n, xt, radius, result);
199
183
  }
200
184
 
201
-
202
-
203
- void IndexPreTransform::reset () {
185
+ void IndexPreTransform::reset() {
204
186
  index->reset();
205
187
  ntotal = 0;
206
188
  }
207
189
 
208
- size_t IndexPreTransform::remove_ids (const IDSelector & sel) {
209
- size_t nremove = index->remove_ids (sel);
190
+ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
191
+ size_t nremove = index->remove_ids(sel);
210
192
  ntotal = index->ntotal;
211
193
  return nremove;
212
194
  }
213
195
 
214
-
215
- void IndexPreTransform::reconstruct (idx_t key, float * recons) const
216
- {
217
- float *x = chain.empty() ? recons : new float [index->d];
218
- ScopeDeleter<float> del (recons == x ? nullptr : x);
196
+ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
197
+ float* x = chain.empty() ? recons : new float[index->d];
198
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
219
199
  // Initial reconstruction
220
- index->reconstruct (key, x);
200
+ index->reconstruct(key, x);
221
201
 
222
202
  // Revert transformations from last to first
223
- reverse_chain (1, x, recons);
203
+ reverse_chain(1, x, recons);
224
204
  }
225
205
 
226
-
227
- void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
228
- {
229
- float *x = chain.empty() ? recons : new float [ni * index->d];
230
- ScopeDeleter<float> del (recons == x ? nullptr : x);
206
+ void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
207
+ float* x = chain.empty() ? recons : new float[ni * index->d];
208
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
231
209
  // Initial reconstruction
232
- index->reconstruct_n (i0, ni, x);
210
+ index->reconstruct_n(i0, ni, x);
233
211
 
234
212
  // Revert transformations from last to first
235
- reverse_chain (ni, x, recons);
213
+ reverse_chain(ni, x, recons);
236
214
  }
237
215
 
216
+ void IndexPreTransform::search_and_reconstruct(
217
+ idx_t n,
218
+ const float* x,
219
+ idx_t k,
220
+ float* distances,
221
+ idx_t* labels,
222
+ float* recons) const {
223
+ FAISS_THROW_IF_NOT(k > 0);
238
224
 
239
- void IndexPreTransform::search_and_reconstruct (
240
- idx_t n, const float *x, idx_t k,
241
- float *distances, idx_t *labels, float* recons) const
242
- {
243
- FAISS_THROW_IF_NOT (is_trained);
225
+ FAISS_THROW_IF_NOT(is_trained);
244
226
 
245
- const float* xt = apply_chain (n, x);
246
- ScopeDeleter<float> del ((xt == x) ? nullptr : xt);
227
+ const float* xt = apply_chain(n, x);
228
+ ScopeDeleter<float> del((xt == x) ? nullptr : xt);
247
229
 
248
- float* recons_temp = chain.empty() ? recons : new float [n * k * index->d];
249
- ScopeDeleter<float> del2 ((recons_temp == recons) ? nullptr : recons_temp);
250
- index->search_and_reconstruct (n, xt, k, distances, labels, recons_temp);
230
+ float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
231
+ ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
232
+ index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
251
233
 
252
234
  // Revert transformations from last to first
253
- reverse_chain (n * k, recons_temp, recons);
235
+ reverse_chain(n * k, recons_temp, recons);
254
236
  }
255
237
 
256
- size_t IndexPreTransform::sa_code_size () const
257
- {
258
- return index->sa_code_size ();
238
+ size_t IndexPreTransform::sa_code_size() const {
239
+ return index->sa_code_size();
259
240
  }
260
241
 
261
- void IndexPreTransform::sa_encode (idx_t n, const float *x,
262
- uint8_t *bytes) const
263
- {
242
+ void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
243
+ const {
264
244
  if (chain.empty()) {
265
- index->sa_encode (n, x, bytes);
245
+ index->sa_encode(n, x, bytes);
266
246
  } else {
267
- const float *xt = apply_chain (n, x);
247
+ const float* xt = apply_chain(n, x);
268
248
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
269
- index->sa_encode (n, xt, bytes);
249
+ index->sa_encode(n, xt, bytes);
270
250
  }
271
251
  }
272
252
 
273
- void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
274
- float *x) const
275
- {
253
+ void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
254
+ const {
276
255
  if (chain.empty()) {
277
- index->sa_decode (n, bytes, x);
256
+ index->sa_decode(n, bytes, x);
278
257
  } else {
279
- std::unique_ptr<float []> x1 (new float [index->d * n]);
280
- index->sa_decode (n, bytes, x1.get());
258
+ std::unique_ptr<float[]> x1(new float[index->d * n]);
259
+ index->sa_decode(n, bytes, x1.get());
281
260
  // Revert transformations from last to first
282
- reverse_chain (n, x1.get(), x);
261
+ reverse_chain(n, x1.get(), x);
283
262
  }
284
263
  }
285
264
 
286
265
  namespace {
287
266
 
288
- struct PreTransformDistanceComputer: DistanceComputer {
289
- const IndexPreTransform *index;
267
+ struct PreTransformDistanceComputer : DistanceComputer {
268
+ const IndexPreTransform* index;
290
269
  std::unique_ptr<DistanceComputer> sub_dc;
291
- std::unique_ptr<const float []> query;
270
+ std::unique_ptr<const float[]> query;
292
271
 
293
- explicit PreTransformDistanceComputer(const IndexPreTransform *index):
294
- index(index),
295
- sub_dc(index->index->get_distance_computer())
296
- {}
272
+ explicit PreTransformDistanceComputer(const IndexPreTransform* index)
273
+ : index(index), sub_dc(index->index->get_distance_computer()) {}
297
274
 
298
- void set_query(const float *x) override {
299
- const float *xt = index->apply_chain (1, x);
275
+ void set_query(const float* x) override {
276
+ const float* xt = index->apply_chain(1, x);
300
277
  if (xt == x) {
301
- sub_dc->set_query (x);
278
+ sub_dc->set_query(x);
302
279
  } else {
303
280
  query.reset(xt);
304
- sub_dc->set_query (xt);
281
+ sub_dc->set_query(xt);
305
282
  }
306
283
  }
307
284
 
308
- float symmetric_dis(idx_t i, idx_t j) override
309
- {
285
+ float symmetric_dis(idx_t i, idx_t j) override {
310
286
  return sub_dc->symmetric_dis(i, j);
311
287
  }
312
288
 
313
- float operator () (idx_t i) override
314
- {
289
+ float operator()(idx_t i) override {
315
290
  return (*sub_dc)(i);
316
291
  }
317
-
318
292
  };
319
293
 
320
-
321
294
  } // anonymous namespace
322
295
 
323
-
324
- DistanceComputer * IndexPreTransform::get_distance_computer() const {
296
+ DistanceComputer* IndexPreTransform::get_distance_computer() const {
325
297
  if (chain.empty()) {
326
298
  return index->get_distance_computer();
327
299
  } else {
@@ -329,6 +301,4 @@ DistanceComputer * IndexPreTransform::get_distance_computer() const {
329
301
  }
330
302
  }
331
303
 
332
-
333
-
334
304
  } // namespace faiss