faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -10,10 +10,10 @@
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
12
  #include <algorithm>
13
- #include <cstdio>
14
13
  #include <cassert>
15
- #include <cstring>
16
14
  #include <cmath>
15
+ #include <cstdio>
16
+ #include <cstring>
17
17
 
18
18
  #include <omp.h>
19
19
 
@@ -21,186 +21,153 @@
21
21
  #include <faiss/impl/FaissAssert.h>
22
22
  #include <faiss/impl/ResultHandler.h>
23
23
 
24
-
25
-
26
24
  #ifndef FINTEGER
27
25
  #define FINTEGER long
28
26
  #endif
29
27
 
30
-
31
28
  extern "C" {
32
29
 
33
30
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
34
31
 
35
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
36
- n, FINTEGER *k, const float *alpha, const float *a,
37
- FINTEGER *lda, const float *b, FINTEGER *
38
- ldb, float *beta, float *c, FINTEGER *ldc);
39
-
40
-
32
+ int sgemm_(
33
+ const char* transa,
34
+ const char* transb,
35
+ FINTEGER* m,
36
+ FINTEGER* n,
37
+ FINTEGER* k,
38
+ const float* alpha,
39
+ const float* a,
40
+ FINTEGER* lda,
41
+ const float* b,
42
+ FINTEGER* ldb,
43
+ float* beta,
44
+ float* c,
45
+ FINTEGER* ldc);
41
46
  }
42
47
 
43
-
44
48
  namespace faiss {
45
49
 
46
-
47
-
48
50
  /***************************************************************************
49
51
  * Matrix/vector ops
50
52
  ***************************************************************************/
51
53
 
52
-
53
-
54
-
55
54
  /* Compute the L2 norm of a set of nx vectors */
56
- void fvec_norms_L2 (float * __restrict nr,
57
- const float * __restrict x,
58
- size_t d, size_t nx)
59
- {
60
-
55
+ void fvec_norms_L2(
56
+ float* __restrict nr,
57
+ const float* __restrict x,
58
+ size_t d,
59
+ size_t nx) {
61
60
  #pragma omp parallel for
62
61
  for (int64_t i = 0; i < nx; i++) {
63
- nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
62
+ nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
64
63
  }
65
64
  }
66
65
 
67
- void fvec_norms_L2sqr (float * __restrict nr,
68
- const float * __restrict x,
69
- size_t d, size_t nx)
70
- {
66
+ void fvec_norms_L2sqr(
67
+ float* __restrict nr,
68
+ const float* __restrict x,
69
+ size_t d,
70
+ size_t nx) {
71
71
  #pragma omp parallel for
72
72
  for (int64_t i = 0; i < nx; i++)
73
- nr[i] = fvec_norm_L2sqr (x + i * d, d);
73
+ nr[i] = fvec_norm_L2sqr(x + i * d, d);
74
74
  }
75
75
 
76
-
77
-
78
- void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
79
- {
76
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
80
77
  #pragma omp parallel for
81
78
  for (int64_t i = 0; i < nx; i++) {
82
- float * __restrict xi = x + i * d;
79
+ float* __restrict xi = x + i * d;
83
80
 
84
- float nr = fvec_norm_L2sqr (xi, d);
81
+ float nr = fvec_norm_L2sqr(xi, d);
85
82
 
86
83
  if (nr > 0) {
87
84
  size_t j;
88
- const float inv_nr = 1.0 / sqrtf (nr);
85
+ const float inv_nr = 1.0 / sqrtf(nr);
89
86
  for (j = 0; j < d; j++)
90
87
  xi[j] *= inv_nr;
91
88
  }
92
89
  }
93
90
  }
94
91
 
95
-
96
-
97
-
98
-
99
-
100
-
101
-
102
-
103
-
104
-
105
-
106
92
  /***************************************************************************
107
93
  * KNN functions
108
94
  ***************************************************************************/
109
95
 
110
96
  namespace {
111
97
 
112
-
113
-
114
98
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
115
- template<class ResultHandler>
116
- void exhaustive_inner_product_seq (
117
- const float * x,
118
- const float * y,
119
- size_t d, size_t nx, size_t ny,
120
- ResultHandler &res)
121
- {
122
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
123
-
124
- check_period *= omp_get_max_threads();
125
-
99
+ template <class ResultHandler>
100
+ void exhaustive_inner_product_seq(
101
+ const float* x,
102
+ const float* y,
103
+ size_t d,
104
+ size_t nx,
105
+ size_t ny,
106
+ ResultHandler& res) {
126
107
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
108
+ int nt = std::min(int(nx), omp_get_max_threads());
127
109
 
128
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
129
- size_t i1 = std::min(i0 + check_period, nx);
130
-
131
- #pragma omp parallel
132
- {
133
- SingleResultHandler resi(res);
110
+ #pragma omp parallel num_threads(nt)
111
+ {
112
+ SingleResultHandler resi(res);
134
113
  #pragma omp for
135
- for (int64_t i = i0; i < i1; i++) {
136
- const float * x_i = x + i * d;
137
- const float * y_j = y;
114
+ for (int64_t i = 0; i < nx; i++) {
115
+ const float* x_i = x + i * d;
116
+ const float* y_j = y;
138
117
 
139
- resi.begin(i);
118
+ resi.begin(i);
140
119
 
141
- for (size_t j = 0; j < ny; j++) {
142
- float ip = fvec_inner_product (x_i, y_j, d);
143
- resi.add_result(ip, j);
144
- y_j += d;
145
- }
146
- resi.end();
120
+ for (size_t j = 0; j < ny; j++) {
121
+ float ip = fvec_inner_product(x_i, y_j, d);
122
+ resi.add_result(ip, j);
123
+ y_j += d;
147
124
  }
125
+ resi.end();
148
126
  }
149
- InterruptCallback::check ();
150
127
  }
151
-
152
128
  }
153
129
 
154
- template<class ResultHandler>
155
- void exhaustive_L2sqr_seq (
156
- const float * x,
157
- const float * y,
158
- size_t d, size_t nx, size_t ny,
159
- ResultHandler & res)
160
- {
161
-
162
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
163
- check_period *= omp_get_max_threads();
130
+ template <class ResultHandler>
131
+ void exhaustive_L2sqr_seq(
132
+ const float* x,
133
+ const float* y,
134
+ size_t d,
135
+ size_t nx,
136
+ size_t ny,
137
+ ResultHandler& res) {
164
138
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
139
+ int nt = std::min(int(nx), omp_get_max_threads());
165
140
 
166
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
167
- size_t i1 = std::min(i0 + check_period, nx);
168
-
169
- #pragma omp parallel
170
- {
171
- SingleResultHandler resi(res);
141
+ #pragma omp parallel num_threads(nt)
142
+ {
143
+ SingleResultHandler resi(res);
172
144
  #pragma omp for
173
- for (int64_t i = i0; i < i1; i++) {
174
- const float * x_i = x + i * d;
175
- const float * y_j = y;
176
- resi.begin(i);
177
- for (size_t j = 0; j < ny; j++) {
178
- float disij = fvec_L2sqr (x_i, y_j, d);
179
- resi.add_result(disij, j);
180
- y_j += d;
181
- }
182
- resi.end();
145
+ for (int64_t i = 0; i < nx; i++) {
146
+ const float* x_i = x + i * d;
147
+ const float* y_j = y;
148
+ resi.begin(i);
149
+ for (size_t j = 0; j < ny; j++) {
150
+ float disij = fvec_L2sqr(x_i, y_j, d);
151
+ resi.add_result(disij, j);
152
+ y_j += d;
183
153
  }
154
+ resi.end();
184
155
  }
185
- InterruptCallback::check ();
186
156
  }
187
-
188
- };
189
-
190
-
191
-
192
-
157
+ }
193
158
 
194
159
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
195
- template<class ResultHandler>
196
- void exhaustive_inner_product_blas (
197
- const float * x,
198
- const float * y,
199
- size_t d, size_t nx, size_t ny,
200
- ResultHandler & res)
201
- {
160
+ template <class ResultHandler>
161
+ void exhaustive_inner_product_blas(
162
+ const float* x,
163
+ const float* y,
164
+ size_t d,
165
+ size_t nx,
166
+ size_t ny,
167
+ ResultHandler& res) {
202
168
  // BLAS does not like empty matrices
203
- if (nx == 0 || ny == 0) return;
169
+ if (nx == 0 || ny == 0)
170
+ return;
204
171
 
205
172
  /* block sizes */
206
173
  const size_t bs_x = distance_compute_blas_query_bs;
@@ -209,86 +176,105 @@ void exhaustive_inner_product_blas (
209
176
 
210
177
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
211
178
  size_t i1 = i0 + bs_x;
212
- if(i1 > nx) i1 = nx;
179
+ if (i1 > nx)
180
+ i1 = nx;
213
181
 
214
182
  res.begin_multiple(i0, i1);
215
183
 
216
184
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
217
185
  size_t j1 = j0 + bs_y;
218
- if (j1 > ny) j1 = ny;
186
+ if (j1 > ny)
187
+ j1 = ny;
219
188
  /* compute the actual dot products */
220
189
  {
221
190
  float one = 1, zero = 0;
222
191
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
223
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
224
- y + j0 * d, &di,
225
- x + i0 * d, &di, &zero,
226
- ip_block.get(), &nyi);
192
+ sgemm_("Transpose",
193
+ "Not transpose",
194
+ &nyi,
195
+ &nxi,
196
+ &di,
197
+ &one,
198
+ y + j0 * d,
199
+ &di,
200
+ x + i0 * d,
201
+ &di,
202
+ &zero,
203
+ ip_block.get(),
204
+ &nyi);
227
205
  }
228
206
 
229
207
  res.add_results(j0, j1, ip_block.get());
230
-
231
208
  }
232
209
  res.end_multiple();
233
- InterruptCallback::check ();
234
-
210
+ InterruptCallback::check();
235
211
  }
236
212
  }
237
213
 
238
-
239
-
240
-
241
214
  // distance correction is an operator that can be applied to transform
242
215
  // the distances
243
- template<class ResultHandler>
244
- void exhaustive_L2sqr_blas (
245
- const float * x,
246
- const float * y,
247
- size_t d, size_t nx, size_t ny,
248
- ResultHandler & res,
249
- const float *y_norms = nullptr)
250
- {
216
+ template <class ResultHandler>
217
+ void exhaustive_L2sqr_blas(
218
+ const float* x,
219
+ const float* y,
220
+ size_t d,
221
+ size_t nx,
222
+ size_t ny,
223
+ ResultHandler& res,
224
+ const float* y_norms = nullptr) {
251
225
  // BLAS does not like empty matrices
252
- if (nx == 0 || ny == 0) return;
226
+ if (nx == 0 || ny == 0)
227
+ return;
253
228
 
254
229
  /* block sizes */
255
230
  const size_t bs_x = distance_compute_blas_query_bs;
256
231
  const size_t bs_y = distance_compute_blas_database_bs;
257
232
  // const size_t bs_x = 16, bs_y = 16;
258
- std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
259
- std::unique_ptr<float []> x_norms(new float[nx]);
260
- std::unique_ptr<float []> del2;
233
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
234
+ std::unique_ptr<float[]> x_norms(new float[nx]);
235
+ std::unique_ptr<float[]> del2;
261
236
 
262
- fvec_norms_L2sqr (x_norms.get(), x, d, nx);
237
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
263
238
 
264
239
  if (!y_norms) {
265
- float *y_norms2 = new float[ny];
240
+ float* y_norms2 = new float[ny];
266
241
  del2.reset(y_norms2);
267
- fvec_norms_L2sqr (y_norms2, y, d, ny);
242
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
268
243
  y_norms = y_norms2;
269
244
  }
270
245
 
271
246
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
272
247
  size_t i1 = i0 + bs_x;
273
- if(i1 > nx) i1 = nx;
248
+ if (i1 > nx)
249
+ i1 = nx;
274
250
 
275
251
  res.begin_multiple(i0, i1);
276
252
 
277
253
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
278
254
  size_t j1 = j0 + bs_y;
279
- if (j1 > ny) j1 = ny;
255
+ if (j1 > ny)
256
+ j1 = ny;
280
257
  /* compute the actual dot products */
281
258
  {
282
259
  float one = 1, zero = 0;
283
260
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
284
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
285
- y + j0 * d, &di,
286
- x + i0 * d, &di, &zero,
287
- ip_block.get(), &nyi);
261
+ sgemm_("Transpose",
262
+ "Not transpose",
263
+ &nyi,
264
+ &nxi,
265
+ &di,
266
+ &one,
267
+ y + j0 * d,
268
+ &di,
269
+ x + i0 * d,
270
+ &di,
271
+ &zero,
272
+ ip_block.get(),
273
+ &nyi);
288
274
  }
289
-
275
+ #pragma omp parallel for
290
276
  for (int64_t i = i0; i < i1; i++) {
291
- float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
277
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
292
278
 
293
279
  for (size_t j = j0; j < j1; j++) {
294
280
  float ip = *ip_line;
@@ -296,7 +282,8 @@ void exhaustive_L2sqr_blas (
296
282
 
297
283
  // negative values can occur for identical vectors
298
284
  // due to roundoff errors
299
- if (dis < 0) dis = 0;
285
+ if (dis < 0)
286
+ dis = 0;
300
287
 
301
288
  *ip_line = dis;
302
289
  ip_line++;
@@ -305,18 +292,12 @@ void exhaustive_L2sqr_blas (
305
292
  res.add_results(j0, j1, ip_block.get());
306
293
  }
307
294
  res.end_multiple();
308
- InterruptCallback::check ();
295
+ InterruptCallback::check();
309
296
  }
310
297
  }
311
298
 
312
-
313
-
314
299
  } // anonymous namespace
315
300
 
316
-
317
-
318
-
319
-
320
301
  /*******************************************************
321
302
  * KNN driver functions
322
303
  *******************************************************/
@@ -326,268 +307,275 @@ int distance_compute_blas_query_bs = 4096;
326
307
  int distance_compute_blas_database_bs = 1024;
327
308
  int distance_compute_min_k_reservoir = 100;
328
309
 
329
- void knn_inner_product (const float * x,
330
- const float * y,
331
- size_t d, size_t nx, size_t ny,
332
- float_minheap_array_t * ha)
333
- {
310
+ void knn_inner_product(
311
+ const float* x,
312
+ const float* y,
313
+ size_t d,
314
+ size_t nx,
315
+ size_t ny,
316
+ float_minheap_array_t* ha) {
334
317
  if (ha->k < distance_compute_min_k_reservoir) {
335
318
  HeapResultHandler<CMin<float, int64_t>> res(
336
- ha->nh, ha->val, ha->ids, ha->k);
319
+ ha->nh, ha->val, ha->ids, ha->k);
337
320
  if (nx < distance_compute_blas_threshold) {
338
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
321
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
339
322
  } else {
340
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
323
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
341
324
  }
342
325
  } else {
343
326
  ReservoirResultHandler<CMin<float, int64_t>> res(
344
- ha->nh, ha->val, ha->ids, ha->k);
327
+ ha->nh, ha->val, ha->ids, ha->k);
345
328
  if (nx < distance_compute_blas_threshold) {
346
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
329
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
347
330
  } else {
348
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
331
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
349
332
  }
350
333
  }
351
334
  }
352
335
 
353
-
354
-
355
-
356
- void knn_L2sqr (
357
- const float * x,
358
- const float * y,
359
- size_t d, size_t nx, size_t ny,
360
- float_maxheap_array_t * ha,
361
- const float *y_norm2
362
- ) {
363
-
336
+ void knn_L2sqr(
337
+ const float* x,
338
+ const float* y,
339
+ size_t d,
340
+ size_t nx,
341
+ size_t ny,
342
+ float_maxheap_array_t* ha,
343
+ const float* y_norm2) {
364
344
  if (ha->k < distance_compute_min_k_reservoir) {
365
345
  HeapResultHandler<CMax<float, int64_t>> res(
366
- ha->nh, ha->val, ha->ids, ha->k);
346
+ ha->nh, ha->val, ha->ids, ha->k);
367
347
 
368
348
  if (nx < distance_compute_blas_threshold) {
369
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
349
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
370
350
  } else {
371
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
351
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
372
352
  }
373
353
  } else {
374
354
  ReservoirResultHandler<CMax<float, int64_t>> res(
375
- ha->nh, ha->val, ha->ids, ha->k);
355
+ ha->nh, ha->val, ha->ids, ha->k);
376
356
  if (nx < distance_compute_blas_threshold) {
377
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
357
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
378
358
  } else {
379
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
359
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
380
360
  }
381
361
  }
382
362
  }
383
363
 
384
-
385
364
  /***************************************************************************
386
365
  * Range search
387
366
  ***************************************************************************/
388
367
 
389
-
390
-
391
-
392
- void range_search_L2sqr (
393
- const float * x,
394
- const float * y,
395
- size_t d, size_t nx, size_t ny,
368
+ void range_search_L2sqr(
369
+ const float* x,
370
+ const float* y,
371
+ size_t d,
372
+ size_t nx,
373
+ size_t ny,
396
374
  float radius,
397
- RangeSearchResult *res)
398
- {
375
+ RangeSearchResult* res) {
399
376
  RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
400
377
  if (nx < distance_compute_blas_threshold) {
401
- exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
378
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
402
379
  } else {
403
- exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
380
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
404
381
  }
405
382
  }
406
383
 
407
- void range_search_inner_product (
408
- const float * x,
409
- const float * y,
410
- size_t d, size_t nx, size_t ny,
384
+ void range_search_inner_product(
385
+ const float* x,
386
+ const float* y,
387
+ size_t d,
388
+ size_t nx,
389
+ size_t ny,
411
390
  float radius,
412
- RangeSearchResult *res)
413
- {
414
-
391
+ RangeSearchResult* res) {
415
392
  RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
416
393
  if (nx < distance_compute_blas_threshold) {
417
- exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
394
+ exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
418
395
  } else {
419
- exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
396
+ exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
420
397
  }
421
398
  }
422
399
 
423
-
424
400
  /***************************************************************************
425
401
  * compute a subset of distances
426
402
  ***************************************************************************/
427
403
 
428
404
  /* compute the inner product between x and a subset y of ny vectors,
429
405
  whose indices are given by ids. */
430
- void fvec_inner_products_by_idx (float * __restrict ip,
431
- const float * x,
432
- const float * y,
433
- const int64_t * __restrict ids, /* for y vecs */
434
- size_t d, size_t nx, size_t ny)
435
- {
406
+ void fvec_inner_products_by_idx(
407
+ float* __restrict ip,
408
+ const float* x,
409
+ const float* y,
410
+ const int64_t* __restrict ids, /* for y vecs */
411
+ size_t d,
412
+ size_t nx,
413
+ size_t ny) {
436
414
  #pragma omp parallel for
437
415
  for (int64_t j = 0; j < nx; j++) {
438
- const int64_t * __restrict idsj = ids + j * ny;
439
- const float * xj = x + j * d;
440
- float * __restrict ipj = ip + j * ny;
416
+ const int64_t* __restrict idsj = ids + j * ny;
417
+ const float* xj = x + j * d;
418
+ float* __restrict ipj = ip + j * ny;
441
419
  for (size_t i = 0; i < ny; i++) {
442
420
  if (idsj[i] < 0)
443
421
  continue;
444
- ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
422
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
445
423
  }
446
424
  }
447
425
  }
448
426
 
449
-
450
-
451
427
  /* compute the squared L2 distance between x and a subset y of ny vectors,
452
428
  whose indices are given by ids. */
453
- void fvec_L2sqr_by_idx (float * __restrict dis,
454
- const float * x,
455
- const float * y,
456
- const int64_t * __restrict ids, /* ids of y vecs */
457
- size_t d, size_t nx, size_t ny)
458
- {
429
+ void fvec_L2sqr_by_idx(
430
+ float* __restrict dis,
431
+ const float* x,
432
+ const float* y,
433
+ const int64_t* __restrict ids, /* ids of y vecs */
434
+ size_t d,
435
+ size_t nx,
436
+ size_t ny) {
459
437
  #pragma omp parallel for
460
438
  for (int64_t j = 0; j < nx; j++) {
461
- const int64_t * __restrict idsj = ids + j * ny;
462
- const float * xj = x + j * d;
463
- float * __restrict disj = dis + j * ny;
439
+ const int64_t* __restrict idsj = ids + j * ny;
440
+ const float* xj = x + j * d;
441
+ float* __restrict disj = dis + j * ny;
464
442
  for (size_t i = 0; i < ny; i++) {
465
443
  if (idsj[i] < 0)
466
444
  continue;
467
- disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
445
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
468
446
  }
469
447
  }
470
448
  }
471
449
 
472
- void pairwise_indexed_L2sqr (
473
- size_t d, size_t n,
474
- const float * x, const int64_t *ix,
475
- const float * y, const int64_t *iy,
476
- float *dis)
477
- {
450
+ void pairwise_indexed_L2sqr(
451
+ size_t d,
452
+ size_t n,
453
+ const float* x,
454
+ const int64_t* ix,
455
+ const float* y,
456
+ const int64_t* iy,
457
+ float* dis) {
478
458
  #pragma omp parallel for
479
459
  for (int64_t j = 0; j < n; j++) {
480
460
  if (ix[j] >= 0 && iy[j] >= 0) {
481
- dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
461
+ dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
482
462
  }
483
463
  }
484
464
  }
485
465
 
486
- void pairwise_indexed_inner_product (
487
- size_t d, size_t n,
488
- const float * x, const int64_t *ix,
489
- const float * y, const int64_t *iy,
490
- float *dis)
491
- {
466
+ void pairwise_indexed_inner_product(
467
+ size_t d,
468
+ size_t n,
469
+ const float* x,
470
+ const int64_t* ix,
471
+ const float* y,
472
+ const int64_t* iy,
473
+ float* dis) {
492
474
  #pragma omp parallel for
493
475
  for (int64_t j = 0; j < n; j++) {
494
476
  if (ix[j] >= 0 && iy[j] >= 0) {
495
- dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
477
+ dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
496
478
  }
497
479
  }
498
480
  }
499
481
 
500
-
501
482
  /* Find the nearest neighbors for nx queries in a set of ny vectors
502
483
  indexed by ids. May be useful for re-ranking a pre-selected vector list */
503
- void knn_inner_products_by_idx (const float * x,
504
- const float * y,
505
- const int64_t * ids,
506
- size_t d, size_t nx, size_t ny,
507
- float_minheap_array_t * res)
508
- {
484
+ void knn_inner_products_by_idx(
485
+ const float* x,
486
+ const float* y,
487
+ const int64_t* ids,
488
+ size_t d,
489
+ size_t nx,
490
+ size_t ny,
491
+ float_minheap_array_t* res) {
509
492
  size_t k = res->k;
510
493
 
511
494
  #pragma omp parallel for
512
495
  for (int64_t i = 0; i < nx; i++) {
513
- const float * x_ = x + i * d;
514
- const int64_t * idsi = ids + i * ny;
496
+ const float* x_ = x + i * d;
497
+ const int64_t* idsi = ids + i * ny;
515
498
  size_t j;
516
- float * __restrict simi = res->get_val(i);
517
- int64_t * __restrict idxi = res->get_ids (i);
518
- minheap_heapify (k, simi, idxi);
499
+ float* __restrict simi = res->get_val(i);
500
+ int64_t* __restrict idxi = res->get_ids(i);
501
+ minheap_heapify(k, simi, idxi);
519
502
 
520
503
  for (j = 0; j < ny; j++) {
521
- if (idsi[j] < 0) break;
522
- float ip = fvec_inner_product (x_, y + d * idsi[j], d);
504
+ if (idsi[j] < 0)
505
+ break;
506
+ float ip = fvec_inner_product(x_, y + d * idsi[j], d);
523
507
 
524
508
  if (ip > simi[0]) {
525
- minheap_replace_top (k, simi, idxi, ip, idsi[j]);
509
+ minheap_replace_top(k, simi, idxi, ip, idsi[j]);
526
510
  }
527
511
  }
528
- minheap_reorder (k, simi, idxi);
512
+ minheap_reorder(k, simi, idxi);
529
513
  }
530
-
531
514
  }
532
515
 
533
- void knn_L2sqr_by_idx (const float * x,
534
- const float * y,
535
- const int64_t * __restrict ids,
536
- size_t d, size_t nx, size_t ny,
537
- float_maxheap_array_t * res)
538
- {
516
+ void knn_L2sqr_by_idx(
517
+ const float* x,
518
+ const float* y,
519
+ const int64_t* __restrict ids,
520
+ size_t d,
521
+ size_t nx,
522
+ size_t ny,
523
+ float_maxheap_array_t* res) {
539
524
  size_t k = res->k;
540
525
 
541
526
  #pragma omp parallel for
542
527
  for (int64_t i = 0; i < nx; i++) {
543
- const float * x_ = x + i * d;
544
- const int64_t * __restrict idsi = ids + i * ny;
545
- float * __restrict simi = res->get_val(i);
546
- int64_t * __restrict idxi = res->get_ids (i);
547
- maxheap_heapify (res->k, simi, idxi);
528
+ const float* x_ = x + i * d;
529
+ const int64_t* __restrict idsi = ids + i * ny;
530
+ float* __restrict simi = res->get_val(i);
531
+ int64_t* __restrict idxi = res->get_ids(i);
532
+ maxheap_heapify(res->k, simi, idxi);
548
533
  for (size_t j = 0; j < ny; j++) {
549
- float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
534
+ float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
550
535
 
551
536
  if (disij < simi[0]) {
552
- maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
537
+ maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
553
538
  }
554
539
  }
555
- maxheap_reorder (res->k, simi, idxi);
540
+ maxheap_reorder(res->k, simi, idxi);
556
541
  }
557
-
558
542
  }
559
543
 
560
-
561
-
562
-
563
-
564
- void pairwise_L2sqr (int64_t d,
565
- int64_t nq, const float *xq,
566
- int64_t nb, const float *xb,
567
- float *dis,
568
- int64_t ldq, int64_t ldb, int64_t ldd)
569
- {
570
- if (nq == 0 || nb == 0) return;
571
- if (ldq == -1) ldq = d;
572
- if (ldb == -1) ldb = d;
573
- if (ldd == -1) ldd = nb;
544
+ void pairwise_L2sqr(
545
+ int64_t d,
546
+ int64_t nq,
547
+ const float* xq,
548
+ int64_t nb,
549
+ const float* xb,
550
+ float* dis,
551
+ int64_t ldq,
552
+ int64_t ldb,
553
+ int64_t ldd) {
554
+ if (nq == 0 || nb == 0)
555
+ return;
556
+ if (ldq == -1)
557
+ ldq = d;
558
+ if (ldb == -1)
559
+ ldb = d;
560
+ if (ldd == -1)
561
+ ldd = nb;
574
562
 
575
563
  // store in beginning of distance matrix to avoid malloc
576
- float *b_norms = dis;
564
+ float* b_norms = dis;
577
565
 
578
566
  #pragma omp parallel for
579
567
  for (int64_t i = 0; i < nb; i++)
580
- b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
568
+ b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
581
569
 
582
570
  #pragma omp parallel for
583
571
  for (int64_t i = 1; i < nq; i++) {
584
- float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
572
+ float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
585
573
  for (int64_t j = 0; j < nb; j++)
586
- dis[i * ldd + j] = q_norm + b_norms [j];
574
+ dis[i * ldd + j] = q_norm + b_norms[j];
587
575
  }
588
576
 
589
577
  {
590
- float q_norm = fvec_norm_L2sqr (xq, d);
578
+ float q_norm = fvec_norm_L2sqr(xq, d);
591
579
  for (int64_t j = 0; j < nb; j++)
592
580
  dis[j] += q_norm;
593
581
  }
@@ -596,22 +584,28 @@ void pairwise_L2sqr (int64_t d,
596
584
  FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
597
585
  float one = 1.0, minus_2 = -2.0;
598
586
 
599
- sgemm_ ("Transposed", "Not transposed",
600
- &nbi, &nqi, &di,
601
- &minus_2,
602
- xb, &ldbi,
603
- xq, &ldqi,
604
- &one, dis, &lddi);
587
+ sgemm_("Transposed",
588
+ "Not transposed",
589
+ &nbi,
590
+ &nqi,
591
+ &di,
592
+ &minus_2,
593
+ xb,
594
+ &ldbi,
595
+ xq,
596
+ &ldqi,
597
+ &one,
598
+ dis,
599
+ &lddi);
605
600
  }
606
-
607
601
  }
608
602
 
609
- void inner_product_to_L2sqr(float* __restrict dis,
610
- const float* nr1,
611
- const float* nr2,
612
- size_t n1, size_t n2)
613
- {
614
-
603
+ void inner_product_to_L2sqr(
604
+ float* __restrict dis,
605
+ const float* nr1,
606
+ const float* nr2,
607
+ size_t n1,
608
+ size_t n2) {
615
609
  #pragma omp parallel for
616
610
  for (int64_t j = 0; j < n1; j++) {
617
611
  float* disj = dis + j * n2;
@@ -620,5 +614,4 @@ void inner_product_to_L2sqr(float* __restrict dis,
620
614
  }
621
615
  }
622
616
 
623
-
624
617
  } // namespace faiss