gte 0.0.6 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +3 -0
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +31 -17
- data/ext/gte/src/model_config.rs +18 -0
- data/ext/gte/src/model_profile.rs +111 -13
- data/ext/gte/src/reranker.rs +42 -21
- data/ext/gte/src/ruby_embedder.rs +39 -20
- data/ext/gte/src/tokenizer.rs +99 -14
- data/ext/gte/tests/inference_integration_test.rs +5 -4
- data/ext/gte/tests/tokenizer_unit_test.rs +5 -2
- data/lib/gte/config.rb +2 -2
- data/lib/gte/embedder.rb +2 -0
- data/lib/gte/reranker.rb +2 -0
- data/lib/gte.rb +1 -0
- metadata +1 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 29659e3ab6072d858b1710a779c3d2e5981f7749782182d141ccd5e9790a1fbb
|
|
4
|
+
data.tar.gz: c42d51cfa1a2ba6a2e83249e8a725c978b11c7ef80c6d69f09a64e884be42031
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: ff2c2b1450a6e82c07aacd2ec98437f03678d56eef9c5516f904021a54f59b2ba5c42b8f6af22b5c4b2dacea98615b99bc54d2c7cdc4e8fbccc1abc195fe9975
|
|
7
|
+
data.tar.gz: 04ca056458d40e2ba7fabcdbcab415a087d54802fb3bd86748dc901c2cf0ecb44072fd1820a73e3dcaca097f165df3e70bab747b38340cd738876af5f0ea7645
|
data/README.md
CHANGED
|
@@ -41,6 +41,7 @@ custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
|
41
41
|
config.with(
|
|
42
42
|
output_tensor: "last_hidden_state",
|
|
43
43
|
max_length: 256,
|
|
44
|
+
padding: "batch_longest",
|
|
44
45
|
optimization_level: 3
|
|
45
46
|
)
|
|
46
47
|
end
|
|
@@ -55,6 +56,7 @@ Config fields and defaults:
|
|
|
55
56
|
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
56
57
|
- `output_tensor`: `nil` (auto-select output tensor)
|
|
57
58
|
- `max_length`: `nil` (uses tokenizer/model defaults)
|
|
59
|
+
- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
|
|
58
60
|
- `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
|
|
59
61
|
|
|
60
62
|
Notes:
|
|
@@ -106,6 +108,7 @@ Reranker config fields and defaults:
|
|
|
106
108
|
- `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
|
|
107
109
|
- `output_tensor`: `nil`
|
|
108
110
|
- `max_length`: `nil`
|
|
111
|
+
- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
|
|
109
112
|
- `execution_providers`: `nil`
|
|
110
113
|
|
|
111
114
|
## Runtime + Result Examples
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.6
|
|
1
|
+
0.0.7
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
-
use crate::model_config::{ExtractorMode, ModelConfig};
|
|
2
|
+
use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
|
|
3
3
|
use crate::model_profile::{
|
|
4
|
-
has_input, infer_extraction_mode,
|
|
5
|
-
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
4
|
+
has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
|
|
5
|
+
resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
6
|
};
|
|
7
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
8
8
|
use crate::session::{build_session, run_session};
|
|
9
|
-
use crate::tokenizer::{Tokenized, Tokenizer};
|
|
9
|
+
use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
|
|
10
10
|
use ndarray::Array2;
|
|
11
11
|
use ort::session::Session;
|
|
12
12
|
use std::path::Path;
|
|
@@ -23,7 +23,13 @@ impl Embedder {
|
|
|
23
23
|
P1: AsRef<Path>,
|
|
24
24
|
P2: AsRef<Path>,
|
|
25
25
|
{
|
|
26
|
-
let tokenizer = Tokenizer::new(
|
|
26
|
+
let tokenizer = Tokenizer::new(
|
|
27
|
+
tokenizer_path,
|
|
28
|
+
config.max_length,
|
|
29
|
+
config.with_type_ids,
|
|
30
|
+
config.padding_mode,
|
|
31
|
+
None,
|
|
32
|
+
)?;
|
|
27
33
|
let session = build_session(model_path, &config)?;
|
|
28
34
|
Ok(Self {
|
|
29
35
|
tokenizer,
|
|
@@ -36,10 +42,7 @@ impl Embedder {
|
|
|
36
42
|
dir: P,
|
|
37
43
|
num_threads: usize,
|
|
38
44
|
optimization_level: u8,
|
|
39
|
-
|
|
40
|
-
output_tensor_override: Option<&str>,
|
|
41
|
-
max_length_override: Option<usize>,
|
|
42
|
-
execution_providers_override: Option<&str>,
|
|
45
|
+
overrides: ModelLoadOverrides<'_>,
|
|
43
46
|
) -> Result<Self> {
|
|
44
47
|
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
|
|
45
48
|
"pooler_output",
|
|
@@ -50,31 +53,35 @@ impl Embedder {
|
|
|
50
53
|
|
|
51
54
|
let dir = dir.as_ref();
|
|
52
55
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
53
|
-
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
56
|
+
let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
54
57
|
Some(name) => resolve_named_model(dir, name)?,
|
|
55
58
|
None => resolve_default_text_model(dir)?,
|
|
56
59
|
};
|
|
57
60
|
|
|
58
|
-
let
|
|
61
|
+
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
62
|
+
let max_length = if let Some(override_value) = overrides.max_length {
|
|
59
63
|
if override_value == 0 {
|
|
60
64
|
return Err(GteError::Inference(
|
|
61
65
|
"max_length override must be greater than 0".to_string(),
|
|
62
66
|
));
|
|
63
67
|
}
|
|
64
|
-
override_value
|
|
68
|
+
override_value.min(tokenizer_profile.safe_max_length)
|
|
65
69
|
} else {
|
|
66
|
-
|
|
70
|
+
tokenizer_profile.default_max_length
|
|
67
71
|
};
|
|
72
|
+
let padding_mode =
|
|
73
|
+
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
68
74
|
|
|
69
75
|
let session_config = ModelConfig {
|
|
70
76
|
max_length,
|
|
77
|
+
padding_mode,
|
|
71
78
|
output_tensor: String::new(),
|
|
72
79
|
mode: ExtractorMode::Raw,
|
|
73
80
|
with_type_ids: false,
|
|
74
81
|
with_attention_mask: true,
|
|
75
82
|
num_threads,
|
|
76
83
|
optimization_level,
|
|
77
|
-
execution_providers:
|
|
84
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
78
85
|
};
|
|
79
86
|
let session = build_session(&model_path, &session_config)?;
|
|
80
87
|
|
|
@@ -82,7 +89,7 @@ impl Embedder {
|
|
|
82
89
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
83
90
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
84
91
|
let output_tensor =
|
|
85
|
-
select_output_tensor(&session,
|
|
92
|
+
select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
|
|
86
93
|
let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
|
|
87
94
|
if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
|
|
88
95
|
return Err(GteError::Inference(
|
|
@@ -92,16 +99,23 @@ impl Embedder {
|
|
|
92
99
|
|
|
93
100
|
let config = ModelConfig {
|
|
94
101
|
max_length,
|
|
102
|
+
padding_mode,
|
|
95
103
|
output_tensor,
|
|
96
104
|
mode,
|
|
97
105
|
with_type_ids,
|
|
98
106
|
with_attention_mask,
|
|
99
107
|
num_threads,
|
|
100
108
|
optimization_level,
|
|
101
|
-
execution_providers:
|
|
109
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
102
110
|
};
|
|
103
111
|
|
|
104
|
-
let tokenizer = Tokenizer::new(
|
|
112
|
+
let tokenizer = Tokenizer::new(
|
|
113
|
+
&tokenizer_path,
|
|
114
|
+
config.max_length,
|
|
115
|
+
config.with_type_ids,
|
|
116
|
+
config.padding_mode,
|
|
117
|
+
tokenizer_profile.fixed_padding_length,
|
|
118
|
+
)?;
|
|
105
119
|
|
|
106
120
|
Ok(Self {
|
|
107
121
|
tokenizer,
|
data/ext/gte/src/model_config.rs
CHANGED
|
@@ -5,9 +5,18 @@ pub enum ExtractorMode {
|
|
|
5
5
|
Raw,
|
|
6
6
|
}
|
|
7
7
|
|
|
8
|
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
|
9
|
+
pub enum PaddingMode {
|
|
10
|
+
#[default]
|
|
11
|
+
Auto,
|
|
12
|
+
BatchLongest,
|
|
13
|
+
Fixed,
|
|
14
|
+
}
|
|
15
|
+
|
|
8
16
|
#[derive(Debug, Clone)]
|
|
9
17
|
pub struct ModelConfig {
|
|
10
18
|
pub max_length: usize,
|
|
19
|
+
pub padding_mode: PaddingMode,
|
|
11
20
|
pub output_tensor: String,
|
|
12
21
|
pub mode: ExtractorMode,
|
|
13
22
|
pub with_type_ids: bool,
|
|
@@ -16,3 +25,12 @@ pub struct ModelConfig {
|
|
|
16
25
|
pub optimization_level: u8,
|
|
17
26
|
pub execution_providers: Option<String>,
|
|
18
27
|
}
|
|
28
|
+
|
|
29
|
+
#[derive(Debug, Clone, Copy, Default)]
|
|
30
|
+
pub struct ModelLoadOverrides<'a> {
|
|
31
|
+
pub model_name: Option<&'a str>,
|
|
32
|
+
pub output_tensor: Option<&'a str>,
|
|
33
|
+
pub max_length: Option<usize>,
|
|
34
|
+
pub padding: Option<&'a str>,
|
|
35
|
+
pub execution_providers: Option<&'a str>,
|
|
36
|
+
}
|
|
@@ -1,9 +1,19 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::ExtractorMode;
|
|
3
3
|
use ort::session::Session;
|
|
4
|
+
use serde_json::Value;
|
|
4
5
|
use std::path::{Path, PathBuf};
|
|
5
6
|
|
|
6
7
|
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
|
|
8
|
+
const DEFAULT_MAX_LENGTH: usize = 512;
|
|
9
|
+
const MAX_SUPPORTED_LENGTH: usize = 8192;
|
|
10
|
+
|
|
11
|
+
#[derive(Debug, Clone, Copy)]
|
|
12
|
+
pub struct TokenizerProfile {
|
|
13
|
+
pub default_max_length: usize,
|
|
14
|
+
pub safe_max_length: usize,
|
|
15
|
+
pub fixed_padding_length: Option<usize>,
|
|
16
|
+
}
|
|
7
17
|
|
|
8
18
|
pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
|
|
9
19
|
let tokenizer_path = dir.join("tokenizer.json");
|
|
@@ -48,19 +58,78 @@ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
|
|
|
48
58
|
)))
|
|
49
59
|
}
|
|
50
60
|
|
|
51
|
-
pub fn
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
Some(
|
|
62
|
-
|
|
63
|
-
|
|
61
|
+
pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
|
|
62
|
+
let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
|
|
63
|
+
let tokenizer_json = read_json(dir.join("tokenizer.json"));
|
|
64
|
+
|
|
65
|
+
let fixed_padding_length = tokenizer_json
|
|
66
|
+
.as_ref()
|
|
67
|
+
.and_then(parse_fixed_padding_length_from_tokenizer_json);
|
|
68
|
+
|
|
69
|
+
let mut candidates = Vec::new();
|
|
70
|
+
if let Some(config) = tokenizer_config.as_ref() {
|
|
71
|
+
if let Some(v) = config.get("max_length").and_then(parse_positive_usize) {
|
|
72
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
73
|
+
}
|
|
74
|
+
if let Some(v) = config.get("model_max_length").and_then(parse_positive_usize) {
|
|
75
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
if let Some(tokenizer) = tokenizer_json.as_ref() {
|
|
80
|
+
if let Some(v) = tokenizer
|
|
81
|
+
.get("truncation")
|
|
82
|
+
.and_then(|truncation| truncation.get("max_length"))
|
|
83
|
+
.and_then(parse_positive_usize)
|
|
84
|
+
{
|
|
85
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
if let Some(v) = fixed_padding_length {
|
|
90
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
let default_max_length = candidates
|
|
94
|
+
.iter()
|
|
95
|
+
.copied()
|
|
96
|
+
.min()
|
|
97
|
+
.unwrap_or(DEFAULT_MAX_LENGTH)
|
|
98
|
+
.max(1);
|
|
99
|
+
let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
|
|
100
|
+
|
|
101
|
+
TokenizerProfile {
|
|
102
|
+
default_max_length,
|
|
103
|
+
safe_max_length,
|
|
104
|
+
fixed_padding_length,
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
fn read_json(path: PathBuf) -> Option<Value> {
|
|
109
|
+
let contents = std::fs::read_to_string(path).ok()?;
|
|
110
|
+
serde_json::from_str(&contents).ok()
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
fn parse_positive_usize(value: &Value) -> Option<usize> {
|
|
114
|
+
let raw = value
|
|
115
|
+
.as_u64()
|
|
116
|
+
.or_else(|| {
|
|
117
|
+
value
|
|
118
|
+
.as_f64()
|
|
119
|
+
.filter(|&v| v.is_finite() && v > 0.0)
|
|
120
|
+
.map(|v| v as u64)
|
|
121
|
+
})
|
|
122
|
+
.or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
|
|
123
|
+
let parsed = usize::try_from(raw).ok()?;
|
|
124
|
+
(parsed > 0).then_some(parsed)
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Option<usize> {
|
|
128
|
+
tokenizer_json
|
|
129
|
+
.get("padding")
|
|
130
|
+
.and_then(|padding| padding.get("strategy"))
|
|
131
|
+
.and_then(|strategy| strategy.get("Fixed"))
|
|
132
|
+
.and_then(parse_positive_usize)
|
|
64
133
|
}
|
|
65
134
|
|
|
66
135
|
pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
|
|
@@ -177,3 +246,32 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
|
|
|
177
246
|
))),
|
|
178
247
|
}
|
|
179
248
|
}
|
|
249
|
+
|
|
250
|
+
#[cfg(test)]
|
|
251
|
+
mod tests {
|
|
252
|
+
use super::{parse_fixed_padding_length_from_tokenizer_json, parse_positive_usize};
|
|
253
|
+
use serde_json::json;
|
|
254
|
+
|
|
255
|
+
#[test]
|
|
256
|
+
fn parse_positive_usize_handles_integer_float_and_string() {
|
|
257
|
+
assert_eq!(parse_positive_usize(&json!(64)), Some(64));
|
|
258
|
+
assert_eq!(parse_positive_usize(&json!(64.0)), Some(64));
|
|
259
|
+
assert_eq!(parse_positive_usize(&json!("64")), Some(64));
|
|
260
|
+
assert_eq!(parse_positive_usize(&json!(0)), None);
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
#[test]
|
|
264
|
+
fn parse_fixed_padding_length_reads_fixed_padding_strategy() {
|
|
265
|
+
let tokenizer_json = json!({
|
|
266
|
+
"padding": {
|
|
267
|
+
"strategy": {
|
|
268
|
+
"Fixed": 64
|
|
269
|
+
}
|
|
270
|
+
}
|
|
271
|
+
});
|
|
272
|
+
assert_eq!(
|
|
273
|
+
parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
|
|
274
|
+
Some(64)
|
|
275
|
+
);
|
|
276
|
+
}
|
|
277
|
+
}
|
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::{ModelLoadOverrides, PaddingMode};
|
|
2
3
|
use crate::model_profile::{
|
|
3
|
-
has_input,
|
|
4
|
-
select_output_tensor, validate_supported_text_inputs,
|
|
4
|
+
has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
|
|
5
|
+
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
5
6
|
};
|
|
6
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
7
8
|
use crate::postprocess::sigmoid_scores;
|
|
8
9
|
use crate::session::build_session;
|
|
9
|
-
use crate::tokenizer::Tokenizer;
|
|
10
|
-
use ndarray::Array1;
|
|
10
|
+
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
11
11
|
use ort::session::Session;
|
|
12
12
|
use std::path::Path;
|
|
13
13
|
|
|
14
14
|
#[derive(Debug, Clone)]
|
|
15
15
|
struct RerankerConfig {
|
|
16
16
|
max_length: usize,
|
|
17
|
+
padding_mode: PaddingMode,
|
|
17
18
|
output_tensor: String,
|
|
18
19
|
with_type_ids: bool,
|
|
19
20
|
with_attention_mask: bool,
|
|
@@ -30,54 +31,62 @@ impl Reranker {
|
|
|
30
31
|
dir: P,
|
|
31
32
|
num_threads: usize,
|
|
32
33
|
optimization_level: u8,
|
|
33
|
-
|
|
34
|
-
output_tensor_override: Option<&str>,
|
|
35
|
-
max_length_override: Option<usize>,
|
|
36
|
-
execution_providers_override: Option<&str>,
|
|
34
|
+
overrides: ModelLoadOverrides<'_>,
|
|
37
35
|
) -> Result<Self> {
|
|
38
36
|
let dir = dir.as_ref();
|
|
39
37
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
40
|
-
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
38
|
+
let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
41
39
|
Some(name) => resolve_named_model(dir, name)?,
|
|
42
40
|
None => resolve_default_text_model(dir)?,
|
|
43
41
|
};
|
|
44
42
|
|
|
45
|
-
let
|
|
43
|
+
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
44
|
+
let max_length = if let Some(override_value) = overrides.max_length {
|
|
46
45
|
if override_value == 0 {
|
|
47
46
|
return Err(GteError::Inference(
|
|
48
47
|
"max_length override must be greater than 0".to_string(),
|
|
49
48
|
));
|
|
50
49
|
}
|
|
51
|
-
override_value
|
|
50
|
+
override_value.min(tokenizer_profile.safe_max_length)
|
|
52
51
|
} else {
|
|
53
|
-
|
|
52
|
+
tokenizer_profile.default_max_length
|
|
54
53
|
};
|
|
54
|
+
let padding_mode =
|
|
55
|
+
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
55
56
|
|
|
56
57
|
let probe_config = crate::model_config::ModelConfig {
|
|
57
58
|
max_length,
|
|
59
|
+
padding_mode,
|
|
58
60
|
output_tensor: String::new(),
|
|
59
61
|
mode: crate::model_config::ExtractorMode::Raw,
|
|
60
62
|
with_type_ids: false,
|
|
61
63
|
with_attention_mask: true,
|
|
62
64
|
num_threads,
|
|
63
65
|
optimization_level,
|
|
64
|
-
execution_providers:
|
|
66
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
65
67
|
};
|
|
66
68
|
let session = build_session(&model_path, &probe_config)?;
|
|
67
69
|
|
|
68
70
|
validate_supported_text_inputs(&session, "text reranking")?;
|
|
69
71
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
70
72
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
71
|
-
let output_tensor = select_output_tensor(&session,
|
|
73
|
+
let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
|
|
72
74
|
|
|
73
75
|
let config = RerankerConfig {
|
|
74
76
|
max_length,
|
|
77
|
+
padding_mode,
|
|
75
78
|
output_tensor,
|
|
76
79
|
with_type_ids,
|
|
77
80
|
with_attention_mask,
|
|
78
81
|
};
|
|
79
82
|
|
|
80
|
-
let tokenizer = Tokenizer::new(
|
|
83
|
+
let tokenizer = Tokenizer::new(
|
|
84
|
+
&tokenizer_path,
|
|
85
|
+
config.max_length,
|
|
86
|
+
config.with_type_ids,
|
|
87
|
+
config.padding_mode,
|
|
88
|
+
tokenizer_profile.fixed_padding_length,
|
|
89
|
+
)?;
|
|
81
90
|
|
|
82
91
|
Ok(Self {
|
|
83
92
|
tokenizer,
|
|
@@ -86,14 +95,27 @@ impl Reranker {
|
|
|
86
95
|
})
|
|
87
96
|
}
|
|
88
97
|
|
|
89
|
-
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<
|
|
98
|
+
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
90
99
|
let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
|
|
91
|
-
|
|
100
|
+
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
pub fn score(&self, query: &str, candidates: &[String], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
104
|
+
let tokenized = self.tokenizer.tokenize_query_candidates(query, candidates)?;
|
|
105
|
+
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
fn score_tokenized(
|
|
109
|
+
&self,
|
|
110
|
+
tokenized: &crate::tokenizer::Tokenized,
|
|
111
|
+
apply_sigmoid: bool,
|
|
112
|
+
) -> Result<Vec<f32>> {
|
|
113
|
+
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
92
114
|
let outputs = self.session.run(input_tensors.inputs)?;
|
|
93
115
|
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
94
116
|
|
|
95
117
|
let mut scores = match array.ndim() {
|
|
96
|
-
1 => array.into_dimensionality::<ndarray::Ix1>()?.
|
|
118
|
+
1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
|
|
97
119
|
2 => {
|
|
98
120
|
let shape = array.shape();
|
|
99
121
|
if shape[1] == 0 {
|
|
@@ -102,7 +124,7 @@ impl Reranker {
|
|
|
102
124
|
self.config.output_tensor, shape
|
|
103
125
|
)));
|
|
104
126
|
}
|
|
105
|
-
array.slice(ndarray::s![.., 0]).
|
|
127
|
+
array.slice(ndarray::s![.., 0]).to_vec()
|
|
106
128
|
}
|
|
107
129
|
n => {
|
|
108
130
|
return Err(GteError::Inference(format!(
|
|
@@ -113,10 +135,9 @@ impl Reranker {
|
|
|
113
135
|
};
|
|
114
136
|
|
|
115
137
|
if apply_sigmoid {
|
|
116
|
-
sigmoid_scores(scores.
|
|
138
|
+
sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
|
|
117
139
|
}
|
|
118
140
|
|
|
119
141
|
Ok(scores)
|
|
120
142
|
}
|
|
121
|
-
|
|
122
143
|
}
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
use crate::embedder::{normalize_l2, Embedder};
|
|
4
4
|
use crate::error::GteError;
|
|
5
|
+
use crate::model_config::ModelLoadOverrides;
|
|
5
6
|
use crate::reranker::Reranker;
|
|
6
7
|
use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
|
|
7
8
|
use std::os::raw::c_void;
|
|
@@ -38,7 +39,8 @@ unsafe impl Send for InferArgs {}
|
|
|
38
39
|
|
|
39
40
|
struct ScoreArgs {
|
|
40
41
|
reranker: *const Reranker,
|
|
41
|
-
|
|
42
|
+
query: *const String,
|
|
43
|
+
candidates: *const Vec<String>,
|
|
42
44
|
apply_sigmoid: bool,
|
|
43
45
|
result: Option<Result<Vec<f32>, GteError>>,
|
|
44
46
|
}
|
|
@@ -85,13 +87,15 @@ fn infer_without_gvl(
|
|
|
85
87
|
|
|
86
88
|
fn score_without_gvl(
|
|
87
89
|
reranker: &Arc<Reranker>,
|
|
88
|
-
|
|
90
|
+
query: String,
|
|
91
|
+
candidates: Vec<String>,
|
|
89
92
|
apply_sigmoid: bool,
|
|
90
93
|
) -> Result<Vec<f32>, Error> {
|
|
91
94
|
let scores = unsafe {
|
|
92
95
|
let mut args = ScoreArgs {
|
|
93
96
|
reranker: Arc::as_ptr(reranker),
|
|
94
|
-
|
|
97
|
+
query: &query as *const String,
|
|
98
|
+
candidates: &candidates as *const Vec<String>,
|
|
95
99
|
apply_sigmoid,
|
|
96
100
|
result: None,
|
|
97
101
|
};
|
|
@@ -135,8 +139,7 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
|
135
139
|
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
136
140
|
let args = &mut *(ptr as *mut ScoreArgs);
|
|
137
141
|
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
138
|
-
|
|
139
|
-
Ok(scores.to_vec())
|
|
142
|
+
(*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
|
|
140
143
|
}));
|
|
141
144
|
args.result = Some(match run_result {
|
|
142
145
|
Ok(result) => result,
|
|
@@ -171,6 +174,7 @@ impl RbEmbedder {
|
|
|
171
174
|
normalize: bool,
|
|
172
175
|
output_tensor: String,
|
|
173
176
|
max_length: usize,
|
|
177
|
+
padding: String,
|
|
174
178
|
execution_providers: String,
|
|
175
179
|
) -> Result<Self, Error> {
|
|
176
180
|
let name = if model_name.is_empty() {
|
|
@@ -193,14 +197,23 @@ impl RbEmbedder {
|
|
|
193
197
|
} else {
|
|
194
198
|
Some(execution_providers.as_str())
|
|
195
199
|
};
|
|
200
|
+
let padding_override = if padding.is_empty() {
|
|
201
|
+
None
|
|
202
|
+
} else {
|
|
203
|
+
Some(padding.as_str())
|
|
204
|
+
};
|
|
205
|
+
let overrides = ModelLoadOverrides {
|
|
206
|
+
model_name: name,
|
|
207
|
+
output_tensor: output_override,
|
|
208
|
+
max_length: max_length_override,
|
|
209
|
+
padding: padding_override,
|
|
210
|
+
execution_providers: execution_providers_override,
|
|
211
|
+
};
|
|
196
212
|
let embedder = Embedder::from_dir(
|
|
197
213
|
&dir_path,
|
|
198
214
|
num_threads,
|
|
199
215
|
optimization_level,
|
|
200
|
-
|
|
201
|
-
output_override,
|
|
202
|
-
max_length_override,
|
|
203
|
-
execution_providers_override,
|
|
216
|
+
overrides,
|
|
204
217
|
)
|
|
205
218
|
.map_err(magnus::Error::from)?;
|
|
206
219
|
Ok(RbEmbedder {
|
|
@@ -231,6 +244,7 @@ impl RbReranker {
|
|
|
231
244
|
sigmoid: bool,
|
|
232
245
|
output_tensor: String,
|
|
233
246
|
max_length: usize,
|
|
247
|
+
padding: String,
|
|
234
248
|
execution_providers: String,
|
|
235
249
|
) -> Result<Self, Error> {
|
|
236
250
|
let name = if model_name.is_empty() {
|
|
@@ -253,14 +267,23 @@ impl RbReranker {
|
|
|
253
267
|
} else {
|
|
254
268
|
Some(execution_providers.as_str())
|
|
255
269
|
};
|
|
270
|
+
let padding_override = if padding.is_empty() {
|
|
271
|
+
None
|
|
272
|
+
} else {
|
|
273
|
+
Some(padding.as_str())
|
|
274
|
+
};
|
|
275
|
+
let overrides = ModelLoadOverrides {
|
|
276
|
+
model_name: name,
|
|
277
|
+
output_tensor: output_override,
|
|
278
|
+
max_length: max_length_override,
|
|
279
|
+
padding: padding_override,
|
|
280
|
+
execution_providers: execution_providers_override,
|
|
281
|
+
};
|
|
256
282
|
let reranker = Reranker::from_dir(
|
|
257
283
|
&dir_path,
|
|
258
284
|
num_threads,
|
|
259
285
|
optimization_level,
|
|
260
|
-
|
|
261
|
-
output_override,
|
|
262
|
-
max_length_override,
|
|
263
|
-
execution_providers_override,
|
|
286
|
+
overrides,
|
|
264
287
|
)
|
|
265
288
|
.map_err(magnus::Error::from)?;
|
|
266
289
|
Ok(RbReranker {
|
|
@@ -276,11 +299,7 @@ impl RbReranker {
|
|
|
276
299
|
candidates: RArray,
|
|
277
300
|
) -> Result<RArray, Error> {
|
|
278
301
|
let candidates: Vec<String> = candidates.to_vec()?;
|
|
279
|
-
let
|
|
280
|
-
.into_iter()
|
|
281
|
-
.map(|candidate| (query.clone(), candidate))
|
|
282
|
-
.collect();
|
|
283
|
-
let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
|
|
302
|
+
let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
|
|
284
303
|
|
|
285
304
|
let out = ruby.ary_new_capa(scores.len());
|
|
286
305
|
for score in scores {
|
|
@@ -376,12 +395,12 @@ impl RbTensor {
|
|
|
376
395
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
377
396
|
let module = ruby.define_module("GTE")?;
|
|
378
397
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
379
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
398
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
|
|
380
399
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
381
400
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
382
401
|
|
|
383
402
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
384
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new,
|
|
403
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
|
|
385
404
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
386
405
|
|
|
387
406
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::PaddingMode;
|
|
2
3
|
use std::path::Path;
|
|
3
4
|
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
|
|
4
5
|
|
|
@@ -20,6 +21,8 @@ impl Tokenizer {
|
|
|
20
21
|
tokenizer_path: P,
|
|
21
22
|
max_length: usize,
|
|
22
23
|
with_type_ids: bool,
|
|
24
|
+
padding_mode: PaddingMode,
|
|
25
|
+
fixed_padding_length: Option<usize>,
|
|
23
26
|
) -> Result<Self> {
|
|
24
27
|
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
|
|
25
28
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
@@ -33,7 +36,7 @@ impl Tokenizer {
|
|
|
33
36
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
34
37
|
|
|
35
38
|
let padding = PaddingParams {
|
|
36
|
-
strategy:
|
|
39
|
+
strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
|
|
37
40
|
..Default::default()
|
|
38
41
|
};
|
|
39
42
|
tokenizer.with_padding(Some(padding));
|
|
@@ -73,6 +76,56 @@ impl Tokenizer {
|
|
|
73
76
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
74
77
|
build_tokenized(&encodings, self.with_type_ids)
|
|
75
78
|
}
|
|
79
|
+
|
|
80
|
+
pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
|
|
81
|
+
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = candidates
|
|
82
|
+
.iter()
|
|
83
|
+
.map(|candidate| (query, candidate.as_str()).into())
|
|
84
|
+
.collect();
|
|
85
|
+
let encodings = self
|
|
86
|
+
.tokenizer
|
|
87
|
+
.encode_batch_fast(encode_inputs, true)
|
|
88
|
+
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
89
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<PaddingMode>> {
|
|
94
|
+
let Some(raw) = value.map(str::trim).filter(|v| !v.is_empty()) else {
|
|
95
|
+
return Ok(None);
|
|
96
|
+
};
|
|
97
|
+
|
|
98
|
+
let normalized = raw.to_ascii_lowercase().replace('-', "_");
|
|
99
|
+
let parsed = match normalized.as_str() {
|
|
100
|
+
"auto" => PaddingMode::Auto,
|
|
101
|
+
"batch_longest" | "batchlongest" => PaddingMode::BatchLongest,
|
|
102
|
+
"fixed" => PaddingMode::Fixed,
|
|
103
|
+
_ => {
|
|
104
|
+
return Err(GteError::Inference(format!(
|
|
105
|
+
"invalid padding mode '{}'; expected one of: auto, batch_longest, fixed",
|
|
106
|
+
raw
|
|
107
|
+
)))
|
|
108
|
+
}
|
|
109
|
+
};
|
|
110
|
+
Ok(Some(parsed))
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
fn resolve_padding_strategy(
|
|
114
|
+
padding_mode: PaddingMode,
|
|
115
|
+
max_length: usize,
|
|
116
|
+
fixed_padding_length: Option<usize>,
|
|
117
|
+
) -> PaddingStrategy {
|
|
118
|
+
match padding_mode {
|
|
119
|
+
PaddingMode::BatchLongest => PaddingStrategy::BatchLongest,
|
|
120
|
+
PaddingMode::Fixed => PaddingStrategy::Fixed(max_length),
|
|
121
|
+
PaddingMode::Auto => {
|
|
122
|
+
if fixed_padding_length.is_some() {
|
|
123
|
+
PaddingStrategy::Fixed(max_length)
|
|
124
|
+
} else {
|
|
125
|
+
PaddingStrategy::BatchLongest
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
}
|
|
76
129
|
}
|
|
77
130
|
|
|
78
131
|
fn build_tokenized_single(
|
|
@@ -121,21 +174,17 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
121
174
|
let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
|
|
122
175
|
|
|
123
176
|
for encoding in encodings {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
);
|
|
177
|
+
for &value in encoding.get_ids() {
|
|
178
|
+
input_ids.push(i64::from(value));
|
|
179
|
+
}
|
|
180
|
+
for &value in encoding.get_attention_mask() {
|
|
181
|
+
attn_masks.push(i64::from(value));
|
|
182
|
+
}
|
|
131
183
|
|
|
132
184
|
if let Some(type_ids) = type_ids.as_mut() {
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
.iter()
|
|
137
|
-
.map(|&value| i64::from(value)),
|
|
138
|
-
);
|
|
185
|
+
for &value in encoding.get_type_ids() {
|
|
186
|
+
type_ids.push(i64::from(value));
|
|
187
|
+
}
|
|
139
188
|
}
|
|
140
189
|
}
|
|
141
190
|
|
|
@@ -147,3 +196,39 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
147
196
|
type_ids,
|
|
148
197
|
})
|
|
149
198
|
}
|
|
199
|
+
|
|
200
|
+
#[cfg(test)]
|
|
201
|
+
mod tests {
|
|
202
|
+
use super::{parse_padding_mode_override, resolve_padding_strategy};
|
|
203
|
+
use crate::model_config::PaddingMode;
|
|
204
|
+
use tokenizers::PaddingStrategy;
|
|
205
|
+
|
|
206
|
+
#[test]
|
|
207
|
+
fn parse_padding_mode_override_accepts_expected_values() {
|
|
208
|
+
assert_eq!(
|
|
209
|
+
parse_padding_mode_override(Some("auto")).unwrap(),
|
|
210
|
+
Some(PaddingMode::Auto)
|
|
211
|
+
);
|
|
212
|
+
assert_eq!(
|
|
213
|
+
parse_padding_mode_override(Some("batch-longest")).unwrap(),
|
|
214
|
+
Some(PaddingMode::BatchLongest)
|
|
215
|
+
);
|
|
216
|
+
assert_eq!(
|
|
217
|
+
parse_padding_mode_override(Some("fixed")).unwrap(),
|
|
218
|
+
Some(PaddingMode::Fixed)
|
|
219
|
+
);
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
#[test]
|
|
223
|
+
fn parse_padding_mode_override_rejects_invalid_values() {
|
|
224
|
+
assert!(parse_padding_mode_override(Some("unknown")).is_err());
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
#[test]
|
|
228
|
+
fn resolve_padding_strategy_uses_fixed_for_auto_when_model_has_fixed_padding() {
|
|
229
|
+
match resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)) {
|
|
230
|
+
PaddingStrategy::Fixed(64) => {}
|
|
231
|
+
other => panic!("expected Fixed(64), got {:?}", other),
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
}
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
use gte::embedder::Embedder;
|
|
2
|
+
use gte::model_config::ModelLoadOverrides;
|
|
2
3
|
|
|
3
4
|
#[test]
|
|
4
5
|
#[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json and model.onnx"]
|
|
5
6
|
fn test_e5_single_embedding_shape() {
|
|
6
7
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
7
8
|
|
|
8
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
9
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
9
10
|
.expect("embedder should initialize");
|
|
10
11
|
let result = embedder
|
|
11
12
|
.embed(vec!["query: Hello world".to_string()])
|
|
@@ -20,7 +21,7 @@ fn test_e5_single_embedding_shape() {
|
|
|
20
21
|
fn test_clip_single_embedding_shape() {
|
|
21
22
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
|
|
22
23
|
|
|
23
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
24
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
24
25
|
.expect("embedder should initialize");
|
|
25
26
|
let result = embedder
|
|
26
27
|
.embed(vec!["a photo of a cat".to_string()])
|
|
@@ -35,7 +36,7 @@ fn test_clip_single_embedding_shape() {
|
|
|
35
36
|
fn test_e5_batch_embedding_shape() {
|
|
36
37
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
37
38
|
|
|
38
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
39
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
39
40
|
.expect("embedder should initialize");
|
|
40
41
|
let texts = vec![
|
|
41
42
|
"query: first sentence".to_string(),
|
|
@@ -54,7 +55,7 @@ fn test_e5_batch_embedding_shape() {
|
|
|
54
55
|
fn test_e5_long_input_truncation_no_error() {
|
|
55
56
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
56
57
|
|
|
57
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
58
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
58
59
|
.expect("embedder should initialize");
|
|
59
60
|
let very_long_text = "word ".repeat(1000);
|
|
60
61
|
let result = embedder
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
use gte::model_config::PaddingMode;
|
|
1
2
|
use gte::tokenizer::Tokenizer;
|
|
2
3
|
|
|
3
4
|
#[test]
|
|
@@ -8,7 +9,8 @@ fn test_e5_tokenizer_output_shape() {
|
|
|
8
9
|
"/tests/fixtures/e5/tokenizer.json"
|
|
9
10
|
);
|
|
10
11
|
|
|
11
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 512, true
|
|
12
|
+
let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
|
|
13
|
+
.expect("tokenizer should load");
|
|
12
14
|
let texts = vec![
|
|
13
15
|
"Hello, world!".to_string(),
|
|
14
16
|
"A second, longer sentence to test padding behavior.".to_string(),
|
|
@@ -33,7 +35,8 @@ fn test_e5_truncation_at_max_length() {
|
|
|
33
35
|
"/tests/fixtures/e5/tokenizer.json"
|
|
34
36
|
);
|
|
35
37
|
|
|
36
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 16, false
|
|
38
|
+
let tokenizer = Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None)
|
|
39
|
+
.expect("tokenizer should load");
|
|
37
40
|
let long_text = "word ".repeat(200);
|
|
38
41
|
let tokenized = tokenizer
|
|
39
42
|
.tokenize(&[long_text])
|
data/lib/gte/config.rb
CHANGED
|
@@ -4,12 +4,12 @@ module GTE
|
|
|
4
4
|
module Config
|
|
5
5
|
Text = Data.define(
|
|
6
6
|
:model_dir, :threads, :optimization_level,
|
|
7
|
-
:model_name, :normalize, :output_tensor, :max_length, :execution_providers
|
|
7
|
+
:model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
|
|
8
8
|
)
|
|
9
9
|
|
|
10
10
|
Reranker = Data.define(
|
|
11
11
|
:model_dir, :threads, :optimization_level,
|
|
12
|
-
:model_name, :sigmoid, :output_tensor, :max_length, :execution_providers
|
|
12
|
+
:model_name, :sigmoid, :output_tensor, :max_length, :padding, :execution_providers
|
|
13
13
|
)
|
|
14
14
|
end
|
|
15
15
|
end
|
data/lib/gte/embedder.rb
CHANGED
|
@@ -18,6 +18,7 @@ module GTE
|
|
|
18
18
|
config.normalize,
|
|
19
19
|
config.output_tensor.to_s,
|
|
20
20
|
config.max_length || 0,
|
|
21
|
+
config.padding.to_s,
|
|
21
22
|
config.execution_providers.to_s
|
|
22
23
|
)
|
|
23
24
|
end
|
|
@@ -33,6 +34,7 @@ module GTE
|
|
|
33
34
|
normalize: true,
|
|
34
35
|
output_tensor: nil,
|
|
35
36
|
max_length: nil,
|
|
37
|
+
padding: nil,
|
|
36
38
|
execution_providers: nil
|
|
37
39
|
)
|
|
38
40
|
end
|
data/lib/gte/reranker.rb
CHANGED
|
@@ -25,6 +25,7 @@ module GTE
|
|
|
25
25
|
sigmoid: false,
|
|
26
26
|
output_tensor: nil,
|
|
27
27
|
max_length: nil,
|
|
28
|
+
padding: nil,
|
|
28
29
|
execution_providers: nil
|
|
29
30
|
)
|
|
30
31
|
end
|
|
@@ -38,6 +39,7 @@ module GTE
|
|
|
38
39
|
cfg.sigmoid,
|
|
39
40
|
cfg.output_tensor.to_s,
|
|
40
41
|
cfg.max_length || 0,
|
|
42
|
+
cfg.padding.to_s,
|
|
41
43
|
cfg.execution_providers.to_s
|
|
42
44
|
)
|
|
43
45
|
end
|
data/lib/gte.rb
CHANGED