vectlite 0.1.12 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,16 @@
1
1
  //! Vector quantization module for memory-efficient similarity search.
2
2
  //!
3
3
  //! Supports three quantization strategies:
4
- //! - **Scalar (int8)**: 4x memory reduction with minimal recall loss
5
- //! - **Binary**: 32x memory reduction, uses Hamming distance for fast filtering
4
+ //! - **Scalar (int8)**: compact in-memory candidate index with minimal recall loss
5
+ //! - **Binary**: smallest in-memory candidate index, uses Hamming distance for fast filtering
6
6
  //! - **Product Quantization (PQ)**: Configurable compression for very large datasets
7
7
  //!
8
8
  //! All strategies support a 2-stage pipeline: fast quantized search followed by
9
9
  //! exact float32 rescoring of top candidates.
10
10
 
11
- use std::io::{Read, Write};
11
+ use std::io::{Error, ErrorKind, Read, Write};
12
+
13
+ use crate::{DistanceMetric, Result, VectLiteError};
12
14
 
13
15
  // ---------------------------------------------------------------------------
14
16
  // Public types
@@ -18,10 +20,10 @@ use std::io::{Read, Write};
18
20
  #[derive(Clone, Debug, PartialEq)]
19
21
  pub enum QuantizationConfig {
20
22
  /// Scalar quantization: maps each f32 dimension to int8 using per-dimension
21
- /// min/max calibration. 4x memory reduction.
23
+ /// min/max calibration for a compact in-memory candidate index.
22
24
  Scalar(ScalarQuantizationConfig),
23
25
  /// Binary quantization: maps each f32 dimension to a single bit.
24
- /// 32x memory reduction. Best for high-dimensional normalized embeddings.
26
+ /// Smallest in-memory candidate index. Best for high-dimensional normalized embeddings.
25
27
  Binary(BinaryQuantizationConfig),
26
28
  /// Product quantization: splits vector into sub-vectors and quantizes each
27
29
  /// to a centroid index. Highest compression for large datasets.
@@ -31,14 +33,14 @@ pub enum QuantizationConfig {
31
33
  #[derive(Clone, Debug, PartialEq)]
32
34
  pub struct ScalarQuantizationConfig {
33
35
  /// Number of top candidates from quantized search to rescore with float32.
34
- /// Default: 5x top_k (minimum 100).
36
+ /// Default: 10x top_k.
35
37
  pub rescore_multiplier: usize,
36
38
  }
37
39
 
38
40
  impl Default for ScalarQuantizationConfig {
39
41
  fn default() -> Self {
40
42
  Self {
41
- rescore_multiplier: 5,
43
+ rescore_multiplier: 10,
42
44
  }
43
45
  }
44
46
  }
@@ -46,7 +48,7 @@ impl Default for ScalarQuantizationConfig {
46
48
  #[derive(Clone, Debug, PartialEq)]
47
49
  pub struct BinaryQuantizationConfig {
48
50
  /// Number of top candidates from Hamming search to rescore with float32.
49
- /// Default: 10x top_k (minimum 100).
51
+ /// Default: 10x top_k.
50
52
  pub rescore_multiplier: usize,
51
53
  }
52
54
 
@@ -69,6 +71,7 @@ pub struct ProductQuantizationConfig {
69
71
  /// Number of k-means training iterations.
70
72
  pub training_iterations: usize,
71
73
  /// Number of top candidates from PQ search to rescore with float32.
74
+ /// Default: 10x top_k.
72
75
  pub rescore_multiplier: usize,
73
76
  }
74
77
 
@@ -83,6 +86,53 @@ impl Default for ProductQuantizationConfig {
83
86
  }
84
87
  }
85
88
 
89
/// Pick a default number of PQ sub-vectors for a vector `dimension`.
///
/// The historical default of 16 is tried first, then the smaller common
/// divisors 12, 10, 8, 6, 4, 3 and 2, and finally 1, which divides everything.
/// This means dimensions such as 100, 146, and 200 work without an explicit
/// `num_sub_vectors`.
pub fn default_product_num_sub_vectors(dimension: usize) -> usize {
    const PREFERRED: [usize; 9] = [16, 12, 10, 8, 6, 4, 3, 2, 1];
    for &candidate in PREFERRED.iter() {
        if dimension % candidate == 0 {
            return candidate;
        }
    }
    1
}
100
+
101
/// Return every divisor of `dimension` in ascending order. Each one is a valid
/// PQ sub-vector count. A zero dimension yields an empty list.
pub fn valid_product_num_sub_vectors(dimension: usize) -> Vec<usize> {
    let mut divisors = Vec::new();
    // `1..=0` is empty, so dimension 0 naturally produces no divisors.
    for candidate in 1..=dimension {
        if dimension % candidate == 0 {
            divisors.push(candidate);
        }
    }
    divisors
}
111
+
112
+ /// Validate quantization settings before an index build can panic.
113
+ pub fn validate_quantization_config(config: &QuantizationConfig, dimension: usize) -> Result<()> {
114
+ if let QuantizationConfig::Product(cfg) = config {
115
+ if cfg.num_sub_vectors == 0 {
116
+ return Err(VectLiteError::InvalidFormat(
117
+ "num_sub_vectors must be greater than 0".to_owned(),
118
+ ));
119
+ }
120
+ if dimension % cfg.num_sub_vectors != 0 {
121
+ return Err(VectLiteError::InvalidFormat(format!(
122
+ "dimension ({dimension}) must be divisible by num_sub_vectors ({})",
123
+ cfg.num_sub_vectors
124
+ )));
125
+ }
126
+ if cfg.num_centroids == 0 || cfg.num_centroids > 256 {
127
+ return Err(VectLiteError::InvalidFormat(
128
+ "num_centroids must be between 1 and 256".to_owned(),
129
+ ));
130
+ }
131
+ }
132
+
133
+ Ok(())
134
+ }
135
+
86
136
  // ---------------------------------------------------------------------------
87
137
  // Scalar Quantization
88
138
  // ---------------------------------------------------------------------------
@@ -173,18 +223,27 @@ impl ScalarQuantizer {
173
223
  .collect()
174
224
  }
175
225
 
176
- /// Compute approximate cosine distance between a quantized query and all stored vectors.
226
+ /// Compute approximate cosine similarity between the query and all stored vectors.
177
227
  /// Returns indices sorted by approximate similarity (best first).
178
228
  pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
179
- let rescore_count = (top_k * self.config.rescore_multiplier)
180
- .max(100)
181
- .min(self.count);
182
- let query_quantized = self.quantize_query(query);
229
+ self.search_with_metric(query, top_k, DistanceMetric::Cosine)
230
+ }
231
+
232
+ /// Compute approximate metric scores between the query and all stored vectors.
233
+ /// Returns indices sorted by approximate score (best first).
234
+ pub fn search_with_metric(
235
+ &self,
236
+ query: &[f32],
237
+ top_k: usize,
238
+ metric: DistanceMetric,
239
+ ) -> Vec<(usize, f32)> {
240
+ assert_eq!(query.len(), self.dimension);
241
+ let rescore_count = rescore_count(top_k, self.config.rescore_multiplier, self.count);
183
242
  let mut scores: Vec<(usize, f32)> = (0..self.count)
184
243
  .map(|idx| {
185
244
  let offset = idx * self.dimension;
186
245
  let code_slice = &self.codes[offset..offset + self.dimension];
187
- let sim = scalar_quantized_dot(&query_quantized, code_slice);
246
+ let sim = self.approximate_score(query, code_slice, metric);
188
247
  (idx, sim)
189
248
  })
190
249
  .collect();
@@ -195,6 +254,71 @@ impl ScalarQuantizer {
195
254
  scores
196
255
  }
197
256
 
257
+ fn approximate_score(&self, query: &[f32], code_slice: &[u8], metric: DistanceMetric) -> f32 {
258
+ match metric {
259
+ DistanceMetric::Cosine => {
260
+ let mut dot = 0.0_f32;
261
+ let mut query_norm = 0.0_f32;
262
+ let mut vector_norm = 0.0_f32;
263
+
264
+ for (((&query_value, &code), &min), &scale) in query
265
+ .iter()
266
+ .zip(code_slice.iter())
267
+ .zip(self.mins.iter())
268
+ .zip(self.scales.iter())
269
+ {
270
+ let value = dequantize_scalar(code, min, scale);
271
+ dot += query_value * value;
272
+ query_norm += query_value * query_value;
273
+ vector_norm += value * value;
274
+ }
275
+
276
+ if query_norm == 0.0 || vector_norm == 0.0 {
277
+ 0.0
278
+ } else {
279
+ dot / (query_norm.sqrt() * vector_norm.sqrt())
280
+ }
281
+ }
282
+ DistanceMetric::Euclidean => {
283
+ let mut sum = 0.0_f32;
284
+ for (((&query_value, &code), &min), &scale) in query
285
+ .iter()
286
+ .zip(code_slice.iter())
287
+ .zip(self.mins.iter())
288
+ .zip(self.scales.iter())
289
+ {
290
+ let delta = query_value - dequantize_scalar(code, min, scale);
291
+ sum += delta * delta;
292
+ }
293
+ -sum.sqrt()
294
+ }
295
+ DistanceMetric::DotProduct => {
296
+ let mut dot = 0.0_f32;
297
+ for (((&query_value, &code), &min), &scale) in query
298
+ .iter()
299
+ .zip(code_slice.iter())
300
+ .zip(self.mins.iter())
301
+ .zip(self.scales.iter())
302
+ {
303
+ dot += query_value * dequantize_scalar(code, min, scale);
304
+ }
305
+ dot
306
+ }
307
+ DistanceMetric::Manhattan => {
308
+ let mut sum = 0.0_f32;
309
+ for (((&query_value, &code), &min), &scale) in query
310
+ .iter()
311
+ .zip(code_slice.iter())
312
+ .zip(self.mins.iter())
313
+ .zip(self.scales.iter())
314
+ {
315
+ sum += (query_value - dequantize_scalar(code, min, scale)).abs();
316
+ }
317
+ -sum
318
+ }
319
+ }
320
+ }
321
+
198
322
  /// Rebuild codes from training vectors (used after deserialization with new vectors).
199
323
  pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
200
324
  self.codes.clear();
@@ -311,9 +435,7 @@ impl BinaryQuantizer {
311
435
  /// Search using Hamming distance. Returns candidate indices sorted by
312
436
  /// Hamming similarity (fewest differing bits first).
313
437
  pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, u32)> {
314
- let rescore_count = (top_k * self.config.rescore_multiplier)
315
- .max(100)
316
- .min(self.count);
438
+ let rescore_count = rescore_count(top_k, self.config.rescore_multiplier, self.count);
317
439
  let query_binary = self.binarize_query(query);
318
440
  let mut distances: Vec<(usize, u32)> = (0..self.count)
319
441
  .map(|idx| {
@@ -476,9 +598,7 @@ impl ProductQuantizer {
476
598
  /// Search using asymmetric distance computation (ADC).
477
599
  /// Returns candidate indices sorted by approximate L2 distance.
478
600
  pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
479
- let rescore_count = (top_k * self.config.rescore_multiplier)
480
- .max(100)
481
- .min(self.count);
601
+ let rescore_count = rescore_count(top_k, self.config.rescore_multiplier, self.count);
482
602
  let distance_table = self.compute_distance_table(query);
483
603
 
484
604
  let mut distances: Vec<(usize, f32)> = (0..self.count)
@@ -542,6 +662,20 @@ impl ProductQuantizer {
542
662
  let num_centroids = read_usize(reader)?;
543
663
  let training_iterations = read_usize(reader)?;
544
664
  let rescore_multiplier = read_usize(reader)?;
665
+ if num_sub_vectors == 0 || dimension % num_sub_vectors != 0 {
666
+ return Err(Error::new(
667
+ ErrorKind::InvalidData,
668
+ format!(
669
+ "dimension ({dimension}) must be divisible by num_sub_vectors ({num_sub_vectors})"
670
+ ),
671
+ ));
672
+ }
673
+ if num_centroids == 0 || num_centroids > 256 {
674
+ return Err(Error::new(
675
+ ErrorKind::InvalidData,
676
+ "num_centroids must be between 1 and 256",
677
+ ));
678
+ }
545
679
  let sub_dimension = dimension / num_sub_vectors;
546
680
 
547
681
  // Read codebooks
@@ -578,6 +712,326 @@ impl ProductQuantizer {
578
712
  }
579
713
  }
580
714
 
715
+ // ---------------------------------------------------------------------------
716
+ // Two-Bit Quantization (ColBERTv2-style)
717
+ // ---------------------------------------------------------------------------
718
+
719
/// Settings for ColBERTv2-style 2-bit multi-vector quantization.
#[derive(Debug, Clone, PartialEq)]
pub struct TwoBitQuantizationConfig {
    /// Multiplier on `top_k` giving how many candidate documents from the
    /// quantized search are rescored with exact float32 MaxSim. Default: 4.
    pub rescore_multiplier: usize,
}
726
+
727
+ impl Default for TwoBitQuantizationConfig {
728
+ fn default() -> Self {
729
+ Self {
730
+ rescore_multiplier: 4,
731
+ }
732
+ }
733
+ }
734
+
735
+ /// Two-bit quantizer: maps each dimension to 2 bits (4 levels) using
736
+ /// per-dimension quartile boundaries. ~16x compression vs float32.
737
+ /// Designed for ColBERT-style token-level vectors.
738
+ #[derive(Clone, Debug)]
739
+ pub struct TwoBitQuantizer {
740
+ pub dimension: usize,
741
+ /// Per-dimension boundary values: [q25, q50, q75] for each dimension.
742
+ /// Shape: dimension * 3.
743
+ pub boundaries: Vec<f32>,
744
+ /// Quantized codes: 2 bits per dimension, packed into bytes.
745
+ /// Each vector uses ceil(dimension / 4) bytes.
746
+ pub codes: Vec<u8>,
747
+ /// Number of quantized vectors.
748
+ pub count: usize,
749
+ /// Bytes per quantized vector.
750
+ pub bytes_per_vector: usize,
751
+ pub config: TwoBitQuantizationConfig,
752
+ }
753
+
754
+ impl TwoBitQuantizer {
755
+ /// Train a 2-bit quantizer by computing per-dimension quartiles.
756
+ pub fn train(vectors: &[&[f32]], dimension: usize, config: TwoBitQuantizationConfig) -> Self {
757
+ assert!(!vectors.is_empty(), "need at least one vector to train");
758
+
759
+ // Collect values per dimension and compute quartile boundaries
760
+ let mut boundaries = Vec::with_capacity(dimension * 3);
761
+ for d in 0..dimension {
762
+ let mut values: Vec<f32> = vectors.iter().map(|v| v[d]).collect();
763
+ values.sort_unstable_by(|a, b| a.total_cmp(b));
764
+ let n = values.len();
765
+ let q25 = values[n / 4];
766
+ let q50 = values[n / 2];
767
+ let q75 = values[(3 * n) / 4];
768
+ boundaries.push(q25);
769
+ boundaries.push(q50);
770
+ boundaries.push(q75);
771
+ }
772
+
773
+ let bytes_per_vector = (dimension + 3) / 4;
774
+ let mut codes = Vec::with_capacity(vectors.len() * bytes_per_vector);
775
+ for vector in vectors {
776
+ codes.extend_from_slice(&quantize_two_bit(vector, &boundaries, bytes_per_vector));
777
+ }
778
+
779
+ Self {
780
+ dimension,
781
+ boundaries,
782
+ codes,
783
+ count: vectors.len(),
784
+ bytes_per_vector,
785
+ config,
786
+ }
787
+ }
788
+
789
+ /// Quantize a single vector to 2-bit codes.
790
+ pub fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
791
+ quantize_two_bit(vector, &self.boundaries, self.bytes_per_vector)
792
+ }
793
+
794
+ /// Compute approximate dot product between a 2-bit quantized query and
795
+ /// a stored quantized vector. Returns a score where higher = more similar.
796
+ pub fn approx_dot(&self, query_codes: &[u8], idx: usize) -> i32 {
797
+ let offset = idx * self.bytes_per_vector;
798
+ let stored = &self.codes[offset..offset + self.bytes_per_vector];
799
+ two_bit_approx_dot(query_codes, stored, self.dimension)
800
+ }
801
+
802
+ /// Search for top-k candidates using approximate 2-bit dot products.
803
+ /// Returns (index, approx_score) pairs sorted best-first.
804
+ pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, i32)> {
805
+ let rescore_count = rescore_count(top_k, self.config.rescore_multiplier, self.count);
806
+ let query_codes = self.quantize_vector(query);
807
+
808
+ let mut scores: Vec<(usize, i32)> = (0..self.count)
809
+ .map(|idx| (idx, self.approx_dot(&query_codes, idx)))
810
+ .collect();
811
+
812
+ scores.sort_unstable_by(|a, b| b.1.cmp(&a.1));
813
+ scores.truncate(rescore_count);
814
+ scores
815
+ }
816
+
817
+ /// Rebuild codes from vectors.
818
+ pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
819
+ self.codes.clear();
820
+ self.codes.reserve(vectors.len() * self.bytes_per_vector);
821
+ for vector in vectors {
822
+ self.codes.extend_from_slice(&quantize_two_bit(
823
+ vector,
824
+ &self.boundaries,
825
+ self.bytes_per_vector,
826
+ ));
827
+ }
828
+ self.count = vectors.len();
829
+ }
830
+
831
+ /// Serialize parameters (boundaries only, codes rebuilt on load).
832
+ pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
833
+ // Tag byte: 4 = two_bit
834
+ writer.write_all(&[4u8])?;
835
+ write_usize(writer, self.dimension)?;
836
+ write_usize(writer, self.config.rescore_multiplier)?;
837
+ // Write boundaries (dimension * 3 floats)
838
+ for &b in &self.boundaries {
839
+ writer.write_all(&b.to_le_bytes())?;
840
+ }
841
+ Ok(())
842
+ }
843
+
844
+ /// Deserialize parameters.
845
+ pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
846
+ let dimension = read_usize(reader)?;
847
+ let rescore_multiplier = read_usize(reader)?;
848
+ let mut boundaries = vec![0.0_f32; dimension * 3];
849
+ for b in &mut boundaries {
850
+ let mut buf = [0u8; 4];
851
+ reader.read_exact(&mut buf)?;
852
+ *b = f32::from_le_bytes(buf);
853
+ }
854
+ let bytes_per_vector = (dimension + 3) / 4;
855
+ Ok(Self {
856
+ dimension,
857
+ boundaries,
858
+ codes: Vec::new(),
859
+ count: 0,
860
+ bytes_per_vector,
861
+ config: TwoBitQuantizationConfig { rescore_multiplier },
862
+ })
863
+ }
864
+ }
865
+
866
+ // ---------------------------------------------------------------------------
867
+ // Multi-vector quantized index (for ColBERT token-level search)
868
+ // ---------------------------------------------------------------------------
869
+
870
+ /// Configuration for multi-vector quantization.
871
+ #[derive(Clone, Debug, PartialEq)]
872
+ pub enum MultiVectorQuantizationConfig {
873
+ TwoBit(TwoBitQuantizationConfig),
874
+ }
875
+
876
+ /// A quantized index for multi-vector (late interaction) search.
877
+ /// Stores all token vectors from all documents in a flat quantized array,
878
+ /// with a mapping from document index to token range.
879
+ #[derive(Clone, Debug)]
880
+ pub struct MultiVectorQuantizedIndex {
881
+ pub quantizer: TwoBitQuantizer,
882
+ /// For each document: (start_index, count) into the quantized vector array.
883
+ pub doc_ranges: Vec<(usize, usize)>,
884
+ }
885
+
886
+ impl MultiVectorQuantizedIndex {
887
+ /// Build a multi-vector quantized index from per-document token vectors.
888
+ /// `doc_token_vectors[i]` is a slice of token-level vectors for document i.
889
+ pub fn build(
890
+ doc_token_vectors: &[&[Vec<f32>]],
891
+ token_dimension: usize,
892
+ config: &MultiVectorQuantizationConfig,
893
+ ) -> Self {
894
+ // Flatten all token vectors for training
895
+ let all_tokens: Vec<&[f32]> = doc_token_vectors
896
+ .iter()
897
+ .flat_map(|tokens| tokens.iter().map(|v| v.as_slice()))
898
+ .collect();
899
+
900
+ let MultiVectorQuantizationConfig::TwoBit(cfg) = config;
901
+
902
+ let quantizer = if all_tokens.is_empty() {
903
+ // Empty case: create minimal quantizer
904
+ TwoBitQuantizer {
905
+ dimension: token_dimension,
906
+ boundaries: vec![0.0; token_dimension * 3],
907
+ codes: Vec::new(),
908
+ count: 0,
909
+ bytes_per_vector: (token_dimension + 3) / 4,
910
+ config: cfg.clone(),
911
+ }
912
+ } else {
913
+ TwoBitQuantizer::train(&all_tokens, token_dimension, cfg.clone())
914
+ };
915
+
916
+ // Build doc_ranges
917
+ let mut doc_ranges = Vec::with_capacity(doc_token_vectors.len());
918
+ let mut offset = 0;
919
+ for tokens in doc_token_vectors {
920
+ doc_ranges.push((offset, tokens.len()));
921
+ offset += tokens.len();
922
+ }
923
+
924
+ Self {
925
+ quantizer,
926
+ doc_ranges,
927
+ }
928
+ }
929
+
930
+ /// Compute approximate MaxSim score for a document given query token codes.
931
+ /// For each query token, finds the max approximate dot with any document token.
932
+ pub fn approx_maxsim(&self, query_codes: &[Vec<u8>], doc_idx: usize) -> i32 {
933
+ let (start, count) = self.doc_ranges[doc_idx];
934
+ if count == 0 || query_codes.is_empty() {
935
+ return 0;
936
+ }
937
+ let mut total = 0i32;
938
+ for q_code in query_codes {
939
+ let mut best = i32::MIN;
940
+ for i in start..start + count {
941
+ let score = two_bit_approx_dot(
942
+ q_code,
943
+ &self.quantizer.codes[i * self.quantizer.bytes_per_vector
944
+ ..(i + 1) * self.quantizer.bytes_per_vector],
945
+ self.quantizer.dimension,
946
+ );
947
+ if score > best {
948
+ best = score;
949
+ }
950
+ }
951
+ total += best;
952
+ }
953
+ total
954
+ }
955
+
956
+ /// Search: returns candidate document indices sorted by approximate MaxSim.
957
+ pub fn search(&self, query_tokens: &[&[f32]], top_k: usize) -> Vec<usize> {
958
+ let rescore_count = rescore_count(
959
+ top_k,
960
+ self.quantizer.config.rescore_multiplier,
961
+ self.doc_ranges.len(),
962
+ );
963
+ if query_tokens.is_empty() || self.doc_ranges.is_empty() {
964
+ return Vec::new();
965
+ }
966
+
967
+ let query_codes: Vec<Vec<u8>> = query_tokens
968
+ .iter()
969
+ .map(|t| self.quantizer.quantize_vector(t))
970
+ .collect();
971
+
972
+ let mut scores: Vec<(usize, i32)> = (0..self.doc_ranges.len())
973
+ .map(|doc_idx| (doc_idx, self.approx_maxsim(&query_codes, doc_idx)))
974
+ .collect();
975
+
976
+ scores.sort_unstable_by(|a, b| b.1.cmp(&a.1));
977
+ scores.truncate(rescore_count);
978
+ scores.into_iter().map(|(idx, _)| idx).collect()
979
+ }
980
+
981
+ /// Rebuild from document token vectors (after loading parameters from disk).
982
+ pub fn rebuild(&mut self, doc_token_vectors: &[&[Vec<f32>]]) {
983
+ let all_tokens: Vec<&[f32]> = doc_token_vectors
984
+ .iter()
985
+ .flat_map(|tokens| tokens.iter().map(|v| v.as_slice()))
986
+ .collect();
987
+ self.quantizer.rebuild_codes(&all_tokens);
988
+
989
+ self.doc_ranges.clear();
990
+ let mut offset = 0;
991
+ for tokens in doc_token_vectors {
992
+ self.doc_ranges.push((offset, tokens.len()));
993
+ offset += tokens.len();
994
+ }
995
+ }
996
+
997
+ /// Serialize parameters.
998
+ pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
999
+ self.quantizer.write_params(writer)?;
1000
+ // Write doc_ranges
1001
+ write_usize(writer, self.doc_ranges.len())?;
1002
+ for &(start, count) in &self.doc_ranges {
1003
+ write_usize(writer, start)?;
1004
+ write_usize(writer, count)?;
1005
+ }
1006
+ Ok(())
1007
+ }
1008
+
1009
+ /// Deserialize parameters.
1010
+ pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
1011
+ // Consume the tag byte written by TwoBitQuantizer::write_params
1012
+ let mut tag = [0u8; 1];
1013
+ reader.read_exact(&mut tag)?;
1014
+ if tag[0] != 4 {
1015
+ return Err(std::io::Error::new(
1016
+ std::io::ErrorKind::InvalidData,
1017
+ format!("expected two_bit tag (4), got {}", tag[0]),
1018
+ ));
1019
+ }
1020
+ let quantizer = TwoBitQuantizer::read_params(reader)?;
1021
+ let num_docs = read_usize(reader)?;
1022
+ let mut doc_ranges = Vec::with_capacity(num_docs);
1023
+ for _ in 0..num_docs {
1024
+ let start = read_usize(reader)?;
1025
+ let count = read_usize(reader)?;
1026
+ doc_ranges.push((start, count));
1027
+ }
1028
+ Ok(Self {
1029
+ quantizer,
1030
+ doc_ranges,
1031
+ })
1032
+ }
1033
+ }
1034
+
581
1035
  // ---------------------------------------------------------------------------
582
1036
  // Unified quantization index
583
1037
  // ---------------------------------------------------------------------------
@@ -612,10 +1066,23 @@ impl QuantizedIndex {
612
1066
  /// Search the quantized index. Returns candidate indices sorted by
613
1067
  /// approximate similarity (best first), to be rescored with exact vectors.
614
1068
  pub fn search_candidates(&self, query: &[f32], top_k: usize) -> Vec<usize> {
1069
+ self.search_candidates_with_metric(query, top_k, DistanceMetric::Cosine)
1070
+ }
1071
+
1072
+ /// Search the quantized index with the database metric.
1073
+ /// Returns candidate indices sorted by approximate score (best first).
1074
+ pub fn search_candidates_with_metric(
1075
+ &self,
1076
+ query: &[f32],
1077
+ top_k: usize,
1078
+ metric: DistanceMetric,
1079
+ ) -> Vec<usize> {
615
1080
  match self {
616
- QuantizedIndex::Scalar(q) => {
617
- q.search(query, top_k).into_iter().map(|(i, _)| i).collect()
618
- }
1081
+ QuantizedIndex::Scalar(q) => q
1082
+ .search_with_metric(query, top_k, metric)
1083
+ .into_iter()
1084
+ .map(|(i, _)| i)
1085
+ .collect(),
619
1086
  QuantizedIndex::Binary(q) => {
620
1087
  q.search(query, top_k).into_iter().map(|(i, _)| i).collect()
621
1088
  }
@@ -696,6 +1163,14 @@ impl QuantizedIndex {
696
1163
  // Internal helper functions
697
1164
  // ---------------------------------------------------------------------------
698
1165
 
1166
/// Number of quantized candidates to rescore: `top_k * rescore_multiplier`,
/// capped at `count`.
///
/// A zero `top_k` or multiplier is treated as 1. The product saturates
/// instead of overflowing.
#[inline]
fn rescore_count(top_k: usize, rescore_multiplier: usize, count: usize) -> usize {
    let k = if top_k == 0 { 1 } else { top_k };
    let multiplier = if rescore_multiplier == 0 { 1 } else { rescore_multiplier };
    std::cmp::min(k.saturating_mul(multiplier), count)
}
1173
+
699
1174
  /// Quantize a single f32 value to u8 using the given min and scale.
700
1175
  #[inline]
701
1176
  fn quantize_scalar(val: f32, min: f32, scale: f32) -> u8 {
@@ -706,15 +1181,13 @@ fn quantize_scalar(val: f32, min: f32, scale: f32) -> u8 {
706
1181
  }
707
1182
  }
708
1183
 
709
/// Reconstruct an approximate `f32` from a `u8` scalar code as `min + code / scale`.
///
/// A zero `scale` maps every code back to `min`. This presumably marks a
/// constant dimension during calibration; confirm against `quantize_scalar`.
#[inline]
fn dequantize_scalar(code: u8, min: f32, scale: f32) -> f32 {
    match scale {
        s if s == 0.0 => min,
        s => min + f32::from(code) / s,
    }
}
719
1192
 
720
1193
  /// Convert a float vector to a binary representation (1 bit per dimension).
@@ -740,6 +1213,45 @@ fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
740
1213
  dist
741
1214
  }
742
1215
 
1216
/// Encode `vector` as packed 2-bit levels.
///
/// Dimension `i` uses the cut points `boundaries[3i..3i + 3]`. Its level is the
/// index of the first cut point the value is `<=` (0, 1 or 2), or 3 if no cut
/// point matches, which includes NaN values. Four levels are packed per byte,
/// least-significant bit pair first.
///
/// Panics if `boundaries` or the `bytes_per_vector`-byte output is too short
/// for `vector`.
fn quantize_two_bit(vector: &[f32], boundaries: &[f32], bytes_per_vector: usize) -> Vec<u8> {
    let mut packed = vec![0u8; bytes_per_vector];
    for (dim, &value) in vector.iter().enumerate() {
        let base = dim * 3;
        // Cut points are checked lazily and in order, like an if/else-if chain.
        let level = (0..3)
            .position(|step| value <= boundaries[base + step])
            .unwrap_or(3) as u8;
        packed[dim / 4] |= level << ((dim % 4) * 2);
    }
    packed
}
1238
+
1239
/// Level-wise dot product of two packed 2-bit code vectors over the first
/// `dimension` dimensions. The raw levels 0..=3 stand in for magnitudes.
/// A higher score means more similar.
///
/// NOTE(review): level 0 adds nothing to the sum. Two vectors that both sit in
/// the lowest quartile score the same there as two that disagree, while
/// agreement in the top quartile adds 9. Consider centered levels
/// (e.g. -3, -1, 1, 3) if ranking quality matters; the unit test pins the
/// current scale.
#[inline]
fn two_bit_approx_dot(a: &[u8], b: &[u8], dimension: usize) -> i32 {
    let level_at =
        |codes: &[u8], dim: usize| i32::from((codes[dim / 4] >> ((dim % 4) * 2)) & 0b11);
    (0..dimension)
        .map(|dim| level_at(a, dim) * level_at(b, dim))
        .sum()
}
1254
+
743
1255
  /// Squared L2 distance between two vectors.
744
1256
  #[inline]
745
1257
  fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32 {
@@ -895,6 +1407,39 @@ mod tests {
895
1407
  assert_eq!(results[0].1, 0); // Hamming distance 0
896
1408
  }
897
1409
 
1410
#[test]
fn rescore_multiplier_controls_candidate_count_without_hidden_floor() {
    let vectors = random_vectors(200, 64, 7);
    let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
    let query = &vectors[0];

    // With no hidden floor, candidates = top_k * multiplier (the pool holds 200 vectors).
    for (multiplier, expected) in [(1, 10), (4, 40)] {
        let scalar = ScalarQuantizer::train(
            &refs,
            64,
            ScalarQuantizationConfig {
                rescore_multiplier: multiplier,
            },
        );
        assert_eq!(scalar.search(query, 10).len(), expected);
    }

    let mut binary = BinaryQuantizer::new(
        64,
        BinaryQuantizationConfig {
            rescore_multiplier: 2,
        },
    );
    binary.add_vectors(&refs);
    assert_eq!(binary.search(query, 10).len(), 20);
}
1442
+
898
1443
  #[test]
899
1444
  fn product_quantization_basic() {
900
1445
  let vectors = random_vectors(200, 128, 42);
@@ -1084,4 +1629,144 @@ mod tests {
1084
1629
  // Bit 7: 0.9 > 0 -> 1
1085
1630
  assert_eq!(binary[0], 0b10100101);
1086
1631
  }
1632
+
1633
#[test]
fn two_bit_quantization_basic() {
    let vectors = random_vectors(100, 64, 42);
    let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
    let quantizer = TwoBitQuantizer::train(
        &refs,
        64,
        TwoBitQuantizationConfig {
            rescore_multiplier: 4,
        },
    );

    // Shape: 64 dims at 2 bits each pack into 16 bytes; 3 cut points per dim.
    assert_eq!(quantizer.dimension, 64);
    assert_eq!(quantizer.count, 100);
    assert_eq!(quantizer.bytes_per_vector, 16);
    assert_eq!(quantizer.boundaries.len(), 64 * 3);

    // The query's own entry should rank near the top.
    let results = quantizer.search(&vectors[0], 10);
    assert!(!results.is_empty());
    let top_five: Vec<usize> = results.iter().take(5).map(|&(idx, _)| idx).collect();
    assert!(top_five.contains(&0));
}
1653
+
1654
#[test]
fn two_bit_quantize_and_approx_dot() {
    // Four dimensions sharing the quartile cut points (-0.5, 0.0, 0.5).
    let boundaries: Vec<f32> = [-0.5, 0.0, 0.5].repeat(4);
    // 4 dims * 2 bits = exactly one packed byte.
    let bytes_per_vector = 1;

    // One value per level: -1.0 -> 0, -0.25 -> 1, 0.25 -> 2, 1.0 -> 3.
    let sample = [-1.0, -0.25, 0.25, 1.0];
    let left = quantize_two_bit(&sample, &boundaries, bytes_per_vector);
    let right = quantize_two_bit(&sample, &boundaries, bytes_per_vector);

    // Self dot product over levels: 0*0 + 1*1 + 2*2 + 3*3 = 14.
    let dot = two_bit_approx_dot(&left, &right, 4);
    assert!(dot > 0);
    assert_eq!(dot, 14);
}
1677
+
1678
#[test]
fn two_bit_serialization_roundtrip() {
    use std::io::Read;

    let vectors = random_vectors(50, 32, 99);
    let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
    let original = TwoBitQuantizer::train(
        &refs,
        32,
        TwoBitQuantizationConfig {
            rescore_multiplier: 6,
        },
    );

    let mut bytes = Vec::new();
    original.write_params(&mut bytes).unwrap();

    let mut cursor = std::io::Cursor::new(&bytes);
    // `write_params` emits a leading tag that `read_params` expects to be consumed already.
    let mut tag = [0u8; 1];
    cursor.read_exact(&mut tag).unwrap();
    assert_eq!(tag[0], 4);
    let restored = TwoBitQuantizer::read_params(&mut cursor).unwrap();

    assert_eq!(original.dimension, restored.dimension);
    assert_eq!(original.boundaries.len(), restored.boundaries.len());
    assert!(original
        .boundaries
        .iter()
        .zip(restored.boundaries.iter())
        .all(|(a, b)| (a - b).abs() < 1e-6));
    assert_eq!(
        original.config.rescore_multiplier,
        restored.config.rescore_multiplier
    );
}
1710
+
1711
#[test]
fn multi_vector_quantized_index_basic() {
    // Five documents of dimension-16 tokens, with 3, 4, 5, 3 and 4 tokens.
    let doc_tokens: Vec<Vec<Vec<f32>>> = (0..5)
        .map(|doc_idx| random_vectors(3 + (doc_idx % 3), 16, 100 + doc_idx as u64))
        .collect();
    let doc_refs: Vec<&[Vec<f32>]> = doc_tokens.iter().map(Vec::as_slice).collect();
    let config = MultiVectorQuantizationConfig::TwoBit(TwoBitQuantizationConfig {
        rescore_multiplier: 4,
    });

    let index = MultiVectorQuantizedIndex::build(&doc_refs, 16, &config);

    assert_eq!(index.doc_ranges.len(), 5);
    let total_tokens: usize = index.doc_ranges.iter().map(|&(_, count)| count).sum();
    assert_eq!(total_tokens, 19);

    // Querying with document 0's own tokens should surface document 0.
    let query_tokens: Vec<&[f32]> = doc_tokens[0].iter().map(Vec::as_slice).collect();
    let results = index.search(&query_tokens, 3);
    assert!(!results.is_empty());
    assert!(results.iter().take(3).any(|&idx| idx == 0));
}
1741
+
1742
#[test]
fn multi_vector_quantized_index_serialization_roundtrip() {
    // Three documents with four 8-dimensional tokens each (seeds 200..=202).
    let doc_tokens: Vec<Vec<Vec<f32>>> = (200..203)
        .map(|seed| random_vectors(4, 8, seed))
        .collect();
    let doc_refs: Vec<&[Vec<f32>]> = doc_tokens.iter().map(Vec::as_slice).collect();

    let config = MultiVectorQuantizationConfig::TwoBit(TwoBitQuantizationConfig {
        rescore_multiplier: 2,
    });
    let original = MultiVectorQuantizedIndex::build(&doc_refs, 8, &config);

    let mut bytes = Vec::new();
    original.write_params(&mut bytes).unwrap();
    let restored =
        MultiVectorQuantizedIndex::read_params(&mut std::io::Cursor::new(&bytes)).unwrap();

    assert_eq!(original.doc_ranges, restored.doc_ranges);
    assert_eq!(original.quantizer.dimension, restored.quantizer.dimension);
    assert_eq!(
        original.quantizer.boundaries.len(),
        restored.quantizer.boundaries.len()
    );
    assert_eq!(
        original.quantizer.config.rescore_multiplier,
        restored.quantizer.config.rescore_multiplier,
    );
}
1087
1772
  }