gte 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +0 -1
- data/README.md +112 -82
- data/Rakefile +0 -9
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +29 -65
- data/ext/gte/src/lib.rs +1 -0
- data/ext/gte/src/model_config.rs +0 -4
- data/ext/gte/src/pipeline.rs +8 -9
- data/ext/gte/src/postprocess.rs +8 -6
- data/ext/gte/src/reranker.rs +7 -10
- data/ext/gte/src/ruby_embedder.rs +10 -33
- data/ext/gte/src/session.rs +50 -156
- data/ext/gte/src/tokenizer.rs +45 -38
- data/ext/gte/tests/embedder_unit_test.rs +1 -1
- data/ext/gte/tests/padding_regression_test.rs +7 -25
- data/ext/gte/tests/tokenizer_unit_test.rs +7 -7
- data/lib/gte/config.rb +1 -2
- data/lib/gte/embedder.rb +2 -14
- data/lib/gte/model.rb +0 -7
- data/lib/gte/reranker.rb +14 -33
- data/lib/gte.rb +4 -25
- metadata +1 -1
|
@@ -1,40 +1,24 @@
|
|
|
1
|
-
// Regression tests for the fixed-padding performance bug.
|
|
2
|
-
//
|
|
3
|
-
// Root cause: PaddingMode::Auto silently read "padding.strategy.Fixed: N" from
|
|
4
|
-
// tokenizer.json and applied it, padding every input to max_length tokens.
|
|
5
|
-
// A query like "cat" (1 token) was padded to 64 tokens for Siglip2, making
|
|
6
|
-
// inference ~6x slower (44ms vs 7ms measured on Heroku).
|
|
7
|
-
//
|
|
8
|
-
// These tests use tests/fixtures/minimal/tokenizer.json which has
|
|
9
|
-
// "padding.strategy.Fixed: 64" baked in — exactly the condition that triggered
|
|
10
|
-
// the regression in production models like Siglip2.
|
|
11
|
-
|
|
12
1
|
use gte::model_config::PaddingMode;
|
|
13
2
|
use gte::tokenizer::Tokenizer;
|
|
14
3
|
|
|
15
4
|
const TOKENIZER: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/minimal/tokenizer.json");
|
|
16
5
|
|
|
17
|
-
// Short input tokenizes to 1 token with this vocabulary.
|
|
18
6
|
const SHORT_INPUT: &str = "cat";
|
|
19
7
|
const MAX_LENGTH: usize = 64;
|
|
20
8
|
|
|
21
9
|
#[test]
|
|
22
10
|
fn auto_padding_uses_batch_longest_regardless_of_tokenizer_json() {
|
|
23
|
-
// fixed_padding_length: Some(MAX_LENGTH) simulates what model_profile::read_tokenizer_profile
|
|
24
|
-
// returns when tokenizer.json has "padding.strategy.Fixed: 64".
|
|
25
11
|
let tokenizer = Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::Auto, Some(MAX_LENGTH))
|
|
26
12
|
.expect("tokenizer should load");
|
|
27
13
|
|
|
28
14
|
let tokenized = tokenizer.tokenize(&[SHORT_INPUT.to_string()]).expect("tokenize should succeed");
|
|
29
15
|
|
|
30
|
-
// Old behavior: cols == 64 (silently padded to max_length)
|
|
31
|
-
// New behavior: cols == actual token count (1 for "cat")
|
|
32
16
|
assert!(
|
|
33
|
-
tokenized.
|
|
17
|
+
tokenized.input_ids.ncols() < MAX_LENGTH,
|
|
34
18
|
"Auto padding should use batch_longest, got cols={} (expected < {}). \
|
|
35
19
|
This is the Siglip2 regression: short queries were padded to max_length, \
|
|
36
20
|
making inference ~6x slower.",
|
|
37
|
-
tokenized.
|
|
21
|
+
tokenized.input_ids.ncols(),
|
|
38
22
|
MAX_LENGTH
|
|
39
23
|
);
|
|
40
24
|
}
|
|
@@ -46,7 +30,7 @@ fn fixed_padding_mode_pads_to_max_length() {
|
|
|
46
30
|
|
|
47
31
|
let tokenized = tokenizer.tokenize(&[SHORT_INPUT.to_string()]).expect("tokenize should succeed");
|
|
48
32
|
|
|
49
|
-
assert_eq!(tokenized.
|
|
33
|
+
assert_eq!(tokenized.input_ids.ncols(), MAX_LENGTH, "Fixed mode should pad to max_length");
|
|
50
34
|
assert_eq!(tokenized.input_ids.len(), MAX_LENGTH);
|
|
51
35
|
assert_eq!(tokenized.attn_masks.len(), MAX_LENGTH);
|
|
52
36
|
}
|
|
@@ -56,26 +40,24 @@ fn batch_longest_padding_uses_longest_sequence_in_batch() {
|
|
|
56
40
|
let tokenizer =
|
|
57
41
|
Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::BatchLongest, None).expect("tokenizer should load");
|
|
58
42
|
|
|
59
|
-
// "cat" = 1 token, "hello world" = 2 tokens — batch pads to 2, not 64
|
|
60
43
|
let tokenized =
|
|
61
44
|
tokenizer.tokenize(&["cat".to_string(), "hello world".to_string()]).expect("tokenize should succeed");
|
|
62
45
|
|
|
63
|
-
assert_eq!(tokenized.
|
|
46
|
+
assert_eq!(tokenized.input_ids.nrows(), 2);
|
|
64
47
|
assert!(
|
|
65
|
-
tokenized.
|
|
48
|
+
tokenized.input_ids.ncols() < MAX_LENGTH,
|
|
66
49
|
"BatchLongest should pad to longest in batch (2 tokens), not max_length ({}). Got cols={}",
|
|
67
50
|
MAX_LENGTH,
|
|
68
|
-
tokenized.
|
|
51
|
+
tokenized.input_ids.ncols()
|
|
69
52
|
);
|
|
70
53
|
}
|
|
71
54
|
|
|
72
55
|
#[test]
|
|
73
56
|
fn auto_padding_with_no_fixed_hint_also_uses_batch_longest() {
|
|
74
|
-
// Sanity check: Auto with fixed_padding_length=None also uses BatchLongest
|
|
75
57
|
let tokenizer =
|
|
76
58
|
Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::Auto, None).expect("tokenizer should load");
|
|
77
59
|
|
|
78
60
|
let tokenized = tokenizer.tokenize(&[SHORT_INPUT.to_string()]).expect("tokenize should succeed");
|
|
79
61
|
|
|
80
|
-
assert!(tokenized.
|
|
62
|
+
assert!(tokenized.input_ids.ncols() < MAX_LENGTH);
|
|
81
63
|
}
|
|
@@ -12,13 +12,13 @@ fn test_e5_tokenizer_output_shape() {
|
|
|
12
12
|
|
|
13
13
|
let tokenized = tokenizer.tokenize(&texts).expect("tokenize should succeed");
|
|
14
14
|
|
|
15
|
-
assert_eq!(tokenized.
|
|
16
|
-
assert!(tokenized.
|
|
17
|
-
assert_eq!(tokenized.input_ids.len(), tokenized.
|
|
18
|
-
assert_eq!(tokenized.attn_masks.len(), tokenized.
|
|
15
|
+
assert_eq!(tokenized.input_ids.nrows(), 2, "batch size should be 2");
|
|
16
|
+
assert!(tokenized.input_ids.ncols() > 0, "sequence length should be non-zero");
|
|
17
|
+
assert_eq!(tokenized.input_ids.len(), tokenized.input_ids.nrows() * tokenized.input_ids.ncols());
|
|
18
|
+
assert_eq!(tokenized.attn_masks.len(), tokenized.attn_masks.nrows() * tokenized.attn_masks.ncols());
|
|
19
19
|
|
|
20
20
|
let type_ids = tokenized.type_ids.as_ref().expect("type_ids should exist");
|
|
21
|
-
assert_eq!(type_ids.len(),
|
|
21
|
+
assert_eq!(type_ids.len(), type_ids.nrows() * type_ids.ncols());
|
|
22
22
|
}
|
|
23
23
|
|
|
24
24
|
#[test]
|
|
@@ -31,6 +31,6 @@ fn test_e5_truncation_at_max_length() {
|
|
|
31
31
|
let long_text = "word ".repeat(200);
|
|
32
32
|
let tokenized = tokenizer.tokenize(&[long_text]).expect("tokenize should not error on long input");
|
|
33
33
|
|
|
34
|
-
assert_eq!(tokenized.
|
|
35
|
-
assert_eq!(tokenized.
|
|
34
|
+
assert_eq!(tokenized.input_ids.nrows(), 1);
|
|
35
|
+
assert_eq!(tokenized.input_ids.ncols(), 16, "sequence length should be truncated to max_length");
|
|
36
36
|
}
|
data/lib/gte/config.rb
CHANGED
|
@@ -4,8 +4,7 @@ module GTE
|
|
|
4
4
|
module Config
|
|
5
5
|
Text = Data.define(
|
|
6
6
|
:model_dir, :optimization_level,
|
|
7
|
-
:model_name, :
|
|
8
|
-
:lowercase_input, :max_input_chars
|
|
7
|
+
:model_name, :output_tensor, :max_length, :padding, :execution_providers
|
|
9
8
|
)
|
|
10
9
|
|
|
11
10
|
Reranker = Data.define(
|
data/lib/gte/embedder.rb
CHANGED
|
@@ -5,24 +5,15 @@ module GTE
|
|
|
5
5
|
DEFAULT_OPTIMIZATION_LEVEL = 3
|
|
6
6
|
|
|
7
7
|
class << self
|
|
8
|
-
def config(model_dir)
|
|
9
|
-
cfg = default_config(model_dir)
|
|
10
|
-
cfg = yield(cfg) if block_given?
|
|
11
|
-
from_config(cfg)
|
|
12
|
-
end
|
|
13
|
-
|
|
14
8
|
def from_config(config)
|
|
15
9
|
new(
|
|
16
10
|
config.model_dir,
|
|
17
11
|
config.optimization_level,
|
|
18
12
|
config.model_name.to_s,
|
|
19
|
-
config.normalize,
|
|
20
13
|
config.output_tensor.to_s,
|
|
21
14
|
config.max_length || 0,
|
|
22
15
|
config.padding.to_s,
|
|
23
|
-
config.execution_providers.to_s
|
|
24
|
-
config.lowercase_input ? true : false,
|
|
25
|
-
config.max_input_chars || 0
|
|
16
|
+
config.execution_providers.to_s
|
|
26
17
|
)
|
|
27
18
|
end
|
|
28
19
|
|
|
@@ -31,13 +22,10 @@ module GTE
|
|
|
31
22
|
model_dir: File.expand_path(model_dir),
|
|
32
23
|
optimization_level: DEFAULT_OPTIMIZATION_LEVEL,
|
|
33
24
|
model_name: nil,
|
|
34
|
-
normalize: true,
|
|
35
25
|
output_tensor: nil,
|
|
36
26
|
max_length: nil,
|
|
37
27
|
padding: nil,
|
|
38
|
-
execution_providers: nil
|
|
39
|
-
lowercase_input: false,
|
|
40
|
-
max_input_chars: nil
|
|
28
|
+
execution_providers: nil
|
|
41
29
|
)
|
|
42
30
|
end
|
|
43
31
|
end
|
data/lib/gte/model.rb
CHANGED
data/lib/gte/reranker.rb
CHANGED
|
@@ -3,15 +3,21 @@
|
|
|
3
3
|
module GTE
|
|
4
4
|
class Reranker
|
|
5
5
|
class << self
|
|
6
|
-
|
|
7
|
-
cfg = default_config(model_dir)
|
|
8
|
-
|
|
9
|
-
if block_given?
|
|
10
|
-
yielded = yield(cfg)
|
|
11
|
-
cfg = yielded if yielded.is_a?(Config::Reranker)
|
|
12
|
-
end
|
|
6
|
+
alias native_new new
|
|
13
7
|
|
|
14
|
-
|
|
8
|
+
def new(model_dir, &block)
|
|
9
|
+
cfg = default_config(model_dir)
|
|
10
|
+
cfg = block.call(cfg) if block
|
|
11
|
+
native_new(
|
|
12
|
+
cfg.model_dir,
|
|
13
|
+
cfg.optimization_level,
|
|
14
|
+
cfg.model_name.to_s,
|
|
15
|
+
cfg.sigmoid,
|
|
16
|
+
cfg.output_tensor.to_s,
|
|
17
|
+
cfg.max_length || 0,
|
|
18
|
+
cfg.padding.to_s,
|
|
19
|
+
cfg.execution_providers.to_s
|
|
20
|
+
)
|
|
15
21
|
end
|
|
16
22
|
|
|
17
23
|
private
|
|
@@ -28,31 +34,6 @@ module GTE
|
|
|
28
34
|
execution_providers: nil
|
|
29
35
|
)
|
|
30
36
|
end
|
|
31
|
-
|
|
32
|
-
def build(cfg)
|
|
33
|
-
new(
|
|
34
|
-
cfg.model_dir,
|
|
35
|
-
cfg.optimization_level,
|
|
36
|
-
cfg.model_name.to_s,
|
|
37
|
-
cfg.sigmoid,
|
|
38
|
-
cfg.output_tensor.to_s,
|
|
39
|
-
cfg.max_length || 0,
|
|
40
|
-
cfg.padding.to_s,
|
|
41
|
-
cfg.execution_providers.to_s,
|
|
42
|
-
false, # lowercase_input
|
|
43
|
-
0 # max_input_chars
|
|
44
|
-
)
|
|
45
|
-
end
|
|
46
|
-
end
|
|
47
|
-
|
|
48
|
-
def rerank(query:, candidates:)
|
|
49
|
-
rows = Array(candidates).map(&:to_s)
|
|
50
|
-
scores = score(query.to_s, rows)
|
|
51
|
-
|
|
52
|
-
rows
|
|
53
|
-
.each_with_index
|
|
54
|
-
.map { |text, idx| { index: idx, score: scores[idx], text: text } }
|
|
55
|
-
.sort_by { |row| -row[:score] }
|
|
56
37
|
end
|
|
57
38
|
end
|
|
58
39
|
end
|
data/lib/gte.rb
CHANGED
|
@@ -14,30 +14,9 @@ require 'gte/model'
|
|
|
14
14
|
require 'gte/reranker'
|
|
15
15
|
|
|
16
16
|
module GTE
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def config(model_dir)
|
|
22
|
-
cfg = Embedder.default_config(model_dir)
|
|
23
|
-
|
|
24
|
-
cfg = yield(cfg) if block_given?
|
|
25
|
-
|
|
26
|
-
@model_cache_mutex.synchronize do
|
|
27
|
-
@model_cache[cache_key(cfg)] ||= Model.new(cfg)
|
|
28
|
-
end
|
|
29
|
-
end
|
|
30
|
-
|
|
31
|
-
def warmup(runner, threads:)
|
|
32
|
-
threads.times.map do
|
|
33
|
-
Thread.new { runner.embed('warmup') }
|
|
34
|
-
end.each(&:join)
|
|
35
|
-
end
|
|
36
|
-
|
|
37
|
-
private
|
|
38
|
-
|
|
39
|
-
def cache_key(cfg)
|
|
40
|
-
cfg.to_h
|
|
41
|
-
end
|
|
17
|
+
def self.config(model_dir, &block)
|
|
18
|
+
cfg = Embedder.default_config(model_dir)
|
|
19
|
+
cfg = block.call(cfg) if block
|
|
20
|
+
Model.new(cfg)
|
|
42
21
|
end
|
|
43
22
|
end
|