gte 0.0.13 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +93 -27
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +26 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +136 -248
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +2 -1
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -24,69 +24,51 @@ impl Tokenizer {
|
|
|
24
24
|
padding_mode: PaddingMode,
|
|
25
25
|
fixed_padding_length: Option<usize>,
|
|
26
26
|
) -> Result<Self> {
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
..Default::default()
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
}
|
|
42
|
-
tokenizer.with_padding(Some(padding));
|
|
43
|
-
|
|
44
|
-
Ok(Self {
|
|
45
|
-
tokenizer,
|
|
46
|
-
with_type_ids,
|
|
47
|
-
})
|
|
27
|
+
#[allow(unused_results)]
|
|
28
|
+
{
|
|
29
|
+
let mut tokenizer =
|
|
30
|
+
tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
31
|
+
|
|
32
|
+
let truncation = TruncationParams { max_length, ..Default::default() };
|
|
33
|
+
let padding = PaddingParams {
|
|
34
|
+
strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
|
|
35
|
+
..Default::default()
|
|
36
|
+
};
|
|
37
|
+
tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
38
|
+
tokenizer.with_padding(Some(padding));
|
|
39
|
+
|
|
40
|
+
Ok(Self { tokenizer, with_type_ids })
|
|
41
|
+
}
|
|
48
42
|
}
|
|
49
43
|
|
|
50
44
|
pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
|
|
51
45
|
if texts.len() == 1 {
|
|
52
|
-
let encoding =
|
|
53
|
-
.tokenizer
|
|
54
|
-
|
|
55
|
-
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
56
|
-
return build_tokenized_single(&encoding, self.with_type_ids);
|
|
46
|
+
let encoding =
|
|
47
|
+
self.tokenizer.encode_fast(texts[0].as_str(), true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
48
|
+
return Ok(build_tokenized_single(&encoding, self.with_type_ids));
|
|
57
49
|
}
|
|
58
50
|
|
|
59
51
|
let encode_inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
|
|
60
|
-
let encodings =
|
|
61
|
-
.tokenizer
|
|
62
|
-
.encode_batch_fast(encode_inputs, true)
|
|
63
|
-
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
52
|
+
let encodings =
|
|
53
|
+
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
64
54
|
|
|
65
|
-
build_tokenized(&encodings, self.with_type_ids)
|
|
55
|
+
Ok(build_tokenized(&encodings, self.with_type_ids))
|
|
66
56
|
}
|
|
67
57
|
|
|
68
58
|
pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
|
|
69
|
-
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
70
|
-
.iter()
|
|
71
|
-
|
|
72
|
-
.
|
|
73
|
-
|
|
74
|
-
.tokenizer
|
|
75
|
-
.encode_batch_fast(encode_inputs, true)
|
|
76
|
-
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
77
|
-
build_tokenized(&encodings, self.with_type_ids)
|
|
59
|
+
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
60
|
+
pairs.iter().map(|(left, right)| (left.as_str(), right.as_str()).into()).collect();
|
|
61
|
+
let encodings =
|
|
62
|
+
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
63
|
+
Ok(build_tokenized(&encodings, self.with_type_ids))
|
|
78
64
|
}
|
|
79
65
|
|
|
80
66
|
pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
|
|
81
|
-
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
82
|
-
.iter()
|
|
83
|
-
|
|
84
|
-
.
|
|
85
|
-
|
|
86
|
-
.tokenizer
|
|
87
|
-
.encode_batch_fast(encode_inputs, true)
|
|
88
|
-
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
89
|
-
build_tokenized(&encodings, self.with_type_ids)
|
|
67
|
+
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
68
|
+
candidates.iter().map(|candidate| (query, candidate.as_str()).into()).collect();
|
|
69
|
+
let encodings =
|
|
70
|
+
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
71
|
+
Ok(build_tokenized(&encodings, self.with_type_ids))
|
|
90
72
|
}
|
|
91
73
|
}
|
|
92
74
|
|
|
@@ -102,8 +84,7 @@ pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<Padding
|
|
|
102
84
|
"fixed" => PaddingMode::Fixed,
|
|
103
85
|
_ => {
|
|
104
86
|
return Err(GteError::Inference(format!(
|
|
105
|
-
"invalid padding mode '{}'; expected one of: auto, batch_longest, fixed"
|
|
106
|
-
raw
|
|
87
|
+
"invalid padding mode '{raw}'; expected one of: auto, batch_longest, fixed"
|
|
107
88
|
)))
|
|
108
89
|
}
|
|
109
90
|
};
|
|
@@ -121,45 +102,20 @@ fn resolve_padding_strategy(
|
|
|
121
102
|
}
|
|
122
103
|
}
|
|
123
104
|
|
|
124
|
-
fn build_tokenized_single(
|
|
125
|
-
encoding: &tokenizers::Encoding,
|
|
126
|
-
with_type_ids: bool,
|
|
127
|
-
) -> Result<Tokenized> {
|
|
105
|
+
fn build_tokenized_single(encoding: &tokenizers::Encoding, with_type_ids: bool) -> Tokenized {
|
|
128
106
|
let cols = encoding.len();
|
|
129
107
|
|
|
130
|
-
let input_ids: Vec<i64> = encoding
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
.map(|&
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
.get_attention_mask()
|
|
137
|
-
.iter()
|
|
138
|
-
.map(|&value| i64::from(value))
|
|
139
|
-
.collect();
|
|
140
|
-
let type_ids: Option<Vec<i64>> = with_type_ids.then(|| {
|
|
141
|
-
encoding
|
|
142
|
-
.get_type_ids()
|
|
143
|
-
.iter()
|
|
144
|
-
.map(|&value| i64::from(value))
|
|
145
|
-
.collect()
|
|
146
|
-
});
|
|
147
|
-
|
|
148
|
-
Ok(Tokenized {
|
|
149
|
-
rows: 1,
|
|
150
|
-
cols,
|
|
151
|
-
input_ids,
|
|
152
|
-
attn_masks,
|
|
153
|
-
type_ids,
|
|
154
|
-
})
|
|
108
|
+
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&v| i64::from(v)).collect();
|
|
109
|
+
let attn_masks: Vec<i64> = encoding.get_attention_mask().iter().map(|&v| i64::from(v)).collect();
|
|
110
|
+
let type_ids: Option<Vec<i64>> =
|
|
111
|
+
with_type_ids.then(|| encoding.get_type_ids().iter().map(|&v| i64::from(v)).collect());
|
|
112
|
+
|
|
113
|
+
Tokenized { rows: 1, cols, input_ids, attn_masks, type_ids }
|
|
155
114
|
}
|
|
156
115
|
|
|
157
|
-
fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Result<Tokenized> {
|
|
116
|
+
fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Tokenized {
|
|
158
117
|
let rows = encodings.len();
|
|
159
|
-
let cols = encodings
|
|
160
|
-
.first()
|
|
161
|
-
.map(|encoding| encoding.len())
|
|
162
|
-
.unwrap_or(0);
|
|
118
|
+
let cols = encodings.first().map_or(0, tokenizers::Encoding::len);
|
|
163
119
|
let len = rows * cols;
|
|
164
120
|
|
|
165
121
|
let mut input_ids = Vec::with_capacity(len);
|
|
@@ -167,27 +123,15 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
167
123
|
let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
|
|
168
124
|
|
|
169
125
|
for encoding in encodings {
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
}
|
|
173
|
-
for &value in encoding.get_attention_mask() {
|
|
174
|
-
attn_masks.push(i64::from(value));
|
|
175
|
-
}
|
|
126
|
+
input_ids.extend(encoding.get_ids().iter().map(|&v| i64::from(v)));
|
|
127
|
+
attn_masks.extend(encoding.get_attention_mask().iter().map(|&v| i64::from(v)));
|
|
176
128
|
|
|
177
129
|
if let Some(type_ids) = type_ids.as_mut() {
|
|
178
|
-
|
|
179
|
-
type_ids.push(i64::from(value));
|
|
180
|
-
}
|
|
130
|
+
type_ids.extend(encoding.get_type_ids().iter().map(|&v| i64::from(v)));
|
|
181
131
|
}
|
|
182
132
|
}
|
|
183
133
|
|
|
184
|
-
|
|
185
|
-
rows,
|
|
186
|
-
cols,
|
|
187
|
-
input_ids,
|
|
188
|
-
attn_masks,
|
|
189
|
-
type_ids,
|
|
190
|
-
})
|
|
134
|
+
Tokenized { rows, cols, input_ids, attn_masks, type_ids }
|
|
191
135
|
}
|
|
192
136
|
|
|
193
137
|
#[cfg(test)]
|
|
@@ -198,18 +142,9 @@ mod tests {
|
|
|
198
142
|
|
|
199
143
|
#[test]
|
|
200
144
|
fn parse_padding_mode_override_accepts_expected_values() {
|
|
201
|
-
assert_eq!(
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
);
|
|
205
|
-
assert_eq!(
|
|
206
|
-
parse_padding_mode_override(Some("batch-longest")).unwrap(),
|
|
207
|
-
Some(PaddingMode::BatchLongest)
|
|
208
|
-
);
|
|
209
|
-
assert_eq!(
|
|
210
|
-
parse_padding_mode_override(Some("fixed")).unwrap(),
|
|
211
|
-
Some(PaddingMode::Fixed)
|
|
212
|
-
);
|
|
145
|
+
assert_eq!(parse_padding_mode_override(Some("auto")).unwrap(), Some(PaddingMode::Auto));
|
|
146
|
+
assert_eq!(parse_padding_mode_override(Some("batch-longest")).unwrap(), Some(PaddingMode::BatchLongest));
|
|
147
|
+
assert_eq!(parse_padding_mode_override(Some("fixed")).unwrap(), Some(PaddingMode::Fixed));
|
|
213
148
|
}
|
|
214
149
|
|
|
215
150
|
#[test]
|
|
@@ -222,21 +157,12 @@ mod tests {
|
|
|
222
157
|
// Auto ignores fixed_padding_length from tokenizer.json — BatchLongest is
|
|
223
158
|
// always faster for inference and correct for variable-length inputs.
|
|
224
159
|
// Use PaddingMode::Fixed explicitly when fixed-length padding is required.
|
|
225
|
-
assert!(matches!(
|
|
226
|
-
|
|
227
|
-
PaddingStrategy::BatchLongest
|
|
228
|
-
));
|
|
229
|
-
assert!(matches!(
|
|
230
|
-
resolve_padding_strategy(PaddingMode::Auto, 512, None),
|
|
231
|
-
PaddingStrategy::BatchLongest
|
|
232
|
-
));
|
|
160
|
+
assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)), PaddingStrategy::BatchLongest));
|
|
161
|
+
assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 512, None), PaddingStrategy::BatchLongest));
|
|
233
162
|
}
|
|
234
163
|
|
|
235
164
|
#[test]
|
|
236
165
|
fn resolve_padding_strategy_fixed_uses_max_length() {
|
|
237
|
-
assert!(matches!(
|
|
238
|
-
resolve_padding_strategy(PaddingMode::Fixed, 64, None),
|
|
239
|
-
PaddingStrategy::Fixed(64)
|
|
240
|
-
));
|
|
166
|
+
assert!(matches!(resolve_padding_strategy(PaddingMode::Fixed, 64, None), PaddingStrategy::Fixed(64)));
|
|
241
167
|
}
|
|
242
168
|
}
|
|
@@ -8,11 +8,8 @@ fn model_dir(env_var: &str) -> Option<String> {
|
|
|
8
8
|
#[test]
|
|
9
9
|
fn test_e5_single_embedding_shape() {
|
|
10
10
|
let Some(dir) = model_dir("GTE_BENCH_E5_DIR") else { return };
|
|
11
|
-
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
|
|
12
|
-
|
|
13
|
-
let result = embedder
|
|
14
|
-
.embed(vec!["query: Hello world".to_string()])
|
|
15
|
-
.expect("embed should succeed");
|
|
11
|
+
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default()).expect("embedder should initialize");
|
|
12
|
+
let result = embedder.embed(&["query: Hello world".to_string()]).expect("embed should succeed");
|
|
16
13
|
|
|
17
14
|
assert_eq!(result.shape()[0], 1);
|
|
18
15
|
assert!(result.shape()[1] > 0);
|
|
@@ -21,11 +18,8 @@ fn test_e5_single_embedding_shape() {
|
|
|
21
18
|
#[test]
|
|
22
19
|
fn test_clip_single_embedding_shape() {
|
|
23
20
|
let Some(dir) = model_dir("GTE_BENCH_CLIP_DIR") else { return };
|
|
24
|
-
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
|
|
25
|
-
|
|
26
|
-
let result = embedder
|
|
27
|
-
.embed(vec!["a photo of a cat".to_string()])
|
|
28
|
-
.expect("embed should succeed");
|
|
21
|
+
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default()).expect("embedder should initialize");
|
|
22
|
+
let result = embedder.embed(&["a photo of a cat".to_string()]).expect("embed should succeed");
|
|
29
23
|
|
|
30
24
|
assert_eq!(result.shape()[0], 1);
|
|
31
25
|
assert!(result.shape()[1] > 0);
|
|
@@ -34,15 +28,14 @@ fn test_clip_single_embedding_shape() {
|
|
|
34
28
|
#[test]
|
|
35
29
|
fn test_e5_batch_embedding_shape() {
|
|
36
30
|
let Some(dir) = model_dir("GTE_BENCH_E5_DIR") else { return };
|
|
37
|
-
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
|
|
38
|
-
.expect("embedder should initialize");
|
|
31
|
+
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default()).expect("embedder should initialize");
|
|
39
32
|
let texts = vec![
|
|
40
33
|
"query: first sentence".to_string(),
|
|
41
34
|
"query: second sentence".to_string(),
|
|
42
35
|
"query: third sentence for batch".to_string(),
|
|
43
36
|
];
|
|
44
37
|
|
|
45
|
-
let result = embedder.embed(texts).expect("batch embed should succeed");
|
|
38
|
+
let result = embedder.embed(&texts).expect("batch embed should succeed");
|
|
46
39
|
|
|
47
40
|
assert_eq!(result.shape()[0], 3);
|
|
48
41
|
assert!(result.shape()[1] > 0);
|
|
@@ -51,12 +44,9 @@ fn test_e5_batch_embedding_shape() {
|
|
|
51
44
|
#[test]
|
|
52
45
|
fn test_e5_long_input_truncation_no_error() {
|
|
53
46
|
let Some(dir) = model_dir("GTE_BENCH_E5_DIR") else { return };
|
|
54
|
-
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
|
|
55
|
-
.expect("embedder should initialize");
|
|
47
|
+
let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default()).expect("embedder should initialize");
|
|
56
48
|
let very_long_text = "word ".repeat(1000);
|
|
57
|
-
let result = embedder
|
|
58
|
-
.embed(vec![very_long_text])
|
|
59
|
-
.expect("long input should be truncated without error");
|
|
49
|
+
let result = embedder.embed(&[very_long_text]).expect("long input should be truncated without error");
|
|
60
50
|
|
|
61
51
|
assert_eq!(result.shape()[0], 1);
|
|
62
52
|
assert!(result.shape()[1] > 0);
|
|
@@ -12,10 +12,7 @@
|
|
|
12
12
|
use gte::model_config::PaddingMode;
|
|
13
13
|
use gte::tokenizer::Tokenizer;
|
|
14
14
|
|
|
15
|
-
const TOKENIZER: &str = concat!(
|
|
16
|
-
env!("CARGO_MANIFEST_DIR"),
|
|
17
|
-
"/tests/fixtures/minimal/tokenizer.json"
|
|
18
|
-
);
|
|
15
|
+
const TOKENIZER: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/minimal/tokenizer.json");
|
|
19
16
|
|
|
20
17
|
// Short input tokenizes to 1 token with this vocabulary.
|
|
21
18
|
const SHORT_INPUT: &str = "cat";
|
|
@@ -28,9 +25,7 @@ fn auto_padding_uses_batch_longest_regardless_of_tokenizer_json() {
|
|
|
28
25
|
let tokenizer = Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::Auto, Some(MAX_LENGTH))
|
|
29
26
|
.expect("tokenizer should load");
|
|
30
27
|
|
|
31
|
-
let tokenized = tokenizer
|
|
32
|
-
.tokenize(&[SHORT_INPUT.to_string()])
|
|
33
|
-
.expect("tokenize should succeed");
|
|
28
|
+
let tokenized = tokenizer.tokenize(&[SHORT_INPUT.to_string()]).expect("tokenize should succeed");
|
|
34
29
|
|
|
35
30
|
// Old behavior: cols == 64 (silently padded to max_length)
|
|
36
31
|
// New behavior: cols == actual token count (1 for "cat")
|
|
@@ -46,30 +41,24 @@ fn auto_padding_uses_batch_longest_regardless_of_tokenizer_json() {
|
|
|
46
41
|
|
|
47
42
|
#[test]
|
|
48
43
|
fn fixed_padding_mode_pads_to_max_length() {
|
|
49
|
-
let tokenizer =
|
|
50
|
-
.expect("tokenizer should load");
|
|
44
|
+
let tokenizer =
|
|
45
|
+
Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::Fixed, None).expect("tokenizer should load");
|
|
51
46
|
|
|
52
|
-
let tokenized = tokenizer
|
|
53
|
-
.tokenize(&[SHORT_INPUT.to_string()])
|
|
54
|
-
.expect("tokenize should succeed");
|
|
47
|
+
let tokenized = tokenizer.tokenize(&[SHORT_INPUT.to_string()]).expect("tokenize should succeed");
|
|
55
48
|
|
|
56
|
-
assert_eq!(
|
|
57
|
-
tokenized.cols, MAX_LENGTH,
|
|
58
|
-
"Fixed mode should pad to max_length"
|
|
59
|
-
);
|
|
49
|
+
assert_eq!(tokenized.cols, MAX_LENGTH, "Fixed mode should pad to max_length");
|
|
60
50
|
assert_eq!(tokenized.input_ids.len(), MAX_LENGTH);
|
|
61
51
|
assert_eq!(tokenized.attn_masks.len(), MAX_LENGTH);
|
|
62
52
|
}
|
|
63
53
|
|
|
64
54
|
#[test]
|
|
65
55
|
fn batch_longest_padding_uses_longest_sequence_in_batch() {
|
|
66
|
-
let tokenizer =
|
|
67
|
-
.expect("tokenizer should load");
|
|
56
|
+
let tokenizer =
|
|
57
|
+
Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::BatchLongest, None).expect("tokenizer should load");
|
|
68
58
|
|
|
69
59
|
// "cat" = 1 token, "hello world" = 2 tokens — batch pads to 2, not 64
|
|
70
|
-
let tokenized =
|
|
71
|
-
.tokenize(&["cat".to_string(), "hello world".to_string()])
|
|
72
|
-
.expect("tokenize should succeed");
|
|
60
|
+
let tokenized =
|
|
61
|
+
tokenizer.tokenize(&["cat".to_string(), "hello world".to_string()]).expect("tokenize should succeed");
|
|
73
62
|
|
|
74
63
|
assert_eq!(tokenized.rows, 2);
|
|
75
64
|
assert!(
|
|
@@ -83,12 +72,10 @@ fn batch_longest_padding_uses_longest_sequence_in_batch() {
|
|
|
83
72
|
#[test]
|
|
84
73
|
fn auto_padding_with_no_fixed_hint_also_uses_batch_longest() {
|
|
85
74
|
// Sanity check: Auto with fixed_padding_length=None also uses BatchLongest
|
|
86
|
-
let tokenizer =
|
|
87
|
-
.expect("tokenizer should load");
|
|
75
|
+
let tokenizer =
|
|
76
|
+
Tokenizer::new(TOKENIZER, MAX_LENGTH, false, PaddingMode::Auto, None).expect("tokenizer should load");
|
|
88
77
|
|
|
89
|
-
let tokenized = tokenizer
|
|
90
|
-
.tokenize(&[SHORT_INPUT.to_string()])
|
|
91
|
-
.expect("tokenize should succeed");
|
|
78
|
+
let tokenized = tokenizer.tokenize(&[SHORT_INPUT.to_string()]).expect("tokenize should succeed");
|
|
92
79
|
|
|
93
80
|
assert!(tokenized.cols < MAX_LENGTH);
|
|
94
81
|
}
|
|
@@ -4,17 +4,11 @@ use gte::tokenizer::Tokenizer;
|
|
|
4
4
|
#[test]
|
|
5
5
|
#[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json"]
|
|
6
6
|
fn test_e5_tokenizer_output_shape() {
|
|
7
|
-
const TOKENIZER: &str = concat!(
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
|
|
13
|
-
.expect("tokenizer should load");
|
|
14
|
-
let texts = vec![
|
|
15
|
-
"Hello, world!".to_string(),
|
|
16
|
-
"A second, longer sentence to test padding behavior.".to_string(),
|
|
17
|
-
];
|
|
7
|
+
const TOKENIZER: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5/tokenizer.json");
|
|
8
|
+
|
|
9
|
+
let tokenizer =
|
|
10
|
+
Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None).expect("tokenizer should load");
|
|
11
|
+
let texts = vec!["Hello, world!".to_string(), "A second, longer sentence to test padding behavior.".to_string()];
|
|
18
12
|
|
|
19
13
|
let tokenized = tokenizer.tokenize(&texts).expect("tokenize should succeed");
|
|
20
14
|
|
|
@@ -30,21 +24,13 @@ fn test_e5_tokenizer_output_shape() {
|
|
|
30
24
|
#[test]
|
|
31
25
|
#[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json"]
|
|
32
26
|
fn test_e5_truncation_at_max_length() {
|
|
33
|
-
const TOKENIZER: &str = concat!(
|
|
34
|
-
env!("CARGO_MANIFEST_DIR"),
|
|
35
|
-
"/tests/fixtures/e5/tokenizer.json"
|
|
36
|
-
);
|
|
27
|
+
const TOKENIZER: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5/tokenizer.json");
|
|
37
28
|
|
|
38
|
-
let tokenizer =
|
|
39
|
-
.expect("tokenizer should load");
|
|
29
|
+
let tokenizer =
|
|
30
|
+
Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None).expect("tokenizer should load");
|
|
40
31
|
let long_text = "word ".repeat(200);
|
|
41
|
-
let tokenized = tokenizer
|
|
42
|
-
.tokenize(&[long_text])
|
|
43
|
-
.expect("tokenize should not error on long input");
|
|
32
|
+
let tokenized = tokenizer.tokenize(&[long_text]).expect("tokenize should not error on long input");
|
|
44
33
|
|
|
45
34
|
assert_eq!(tokenized.rows, 1);
|
|
46
|
-
assert_eq!(
|
|
47
|
-
tokenized.cols, 16,
|
|
48
|
-
"sequence length should be truncated to max_length"
|
|
49
|
-
);
|
|
35
|
+
assert_eq!(tokenized.cols, 16, "sequence length should be truncated to max_length");
|
|
50
36
|
}
|
data/lib/gte/config.rb
CHANGED
|
@@ -4,7 +4,8 @@ module GTE
|
|
|
4
4
|
module Config
|
|
5
5
|
Text = Data.define(
|
|
6
6
|
:model_dir, :optimization_level,
|
|
7
|
-
:model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
|
|
7
|
+
:model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers,
|
|
8
|
+
:lowercase_input, :max_input_chars
|
|
8
9
|
)
|
|
9
10
|
|
|
10
11
|
Reranker = Data.define(
|
data/lib/gte/embedder.rb
CHANGED
|
@@ -20,7 +20,9 @@ module GTE
|
|
|
20
20
|
config.output_tensor.to_s,
|
|
21
21
|
config.max_length || 0,
|
|
22
22
|
config.padding.to_s,
|
|
23
|
-
config.execution_providers.to_s
|
|
23
|
+
config.execution_providers.to_s,
|
|
24
|
+
config.lowercase_input ? true : false,
|
|
25
|
+
config.max_input_chars || 0
|
|
24
26
|
)
|
|
25
27
|
end
|
|
26
28
|
|
|
@@ -33,7 +35,9 @@ module GTE
|
|
|
33
35
|
output_tensor: nil,
|
|
34
36
|
max_length: nil,
|
|
35
37
|
padding: nil,
|
|
36
|
-
execution_providers: nil
|
|
38
|
+
execution_providers: nil,
|
|
39
|
+
lowercase_input: false,
|
|
40
|
+
max_input_chars: nil
|
|
37
41
|
)
|
|
38
42
|
end
|
|
39
43
|
end
|
data/lib/gte/reranker.rb
CHANGED
data/lib/gte.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: gte
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.0.13
|
|
4
|
+
version: 0.0.14
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- elcuervo
|
|
@@ -95,6 +95,7 @@ files:
|
|
|
95
95
|
- ext/gte/benches/hot_path.rs
|
|
96
96
|
- ext/gte/build.rs
|
|
97
97
|
- ext/gte/extconf.rb
|
|
98
|
+
- ext/gte/rustfmt.toml
|
|
98
99
|
- ext/gte/src/embedder.rs
|
|
99
100
|
- ext/gte/src/error.rs
|
|
100
101
|
- ext/gte/src/lib.rs
|