gte 0.0.5 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +37 -3
- data/Rakefile +2 -2
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +31 -14
- data/ext/gte/src/model_config.rs +19 -0
- data/ext/gte/src/model_profile.rs +111 -13
- data/ext/gte/src/reranker.rs +42 -19
- data/ext/gte/src/ruby_embedder.rs +51 -18
- data/ext/gte/src/session.rs +67 -7
- data/ext/gte/src/tokenizer.rs +99 -14
- data/ext/gte/tests/inference_integration_test.rs +9 -8
- data/ext/gte/tests/tokenizer_unit_test.rs +5 -2
- data/lib/gte/config.rb +2 -2
- data/lib/gte/embedder.rb +43 -0
- data/lib/gte/model.rb +1 -9
- data/lib/gte/reranker.rb +6 -2
- data/lib/gte.rb +4 -1
- metadata +3 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 29659e3ab6072d858b1710a779c3d2e5981f7749782182d141ccd5e9790a1fbb
|
|
4
|
+
data.tar.gz: c42d51cfa1a2ba6a2e83249e8a725c978b11c7ef80c6d69f09a64e884be42031
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: ff2c2b1450a6e82c07aacd2ec98437f03678d56eef9c5516f904021a54f59b2ba5c42b8f6af22b5c4b2dacea98615b99bc54d2c7cdc4e8fbccc1abc195fe9975
|
|
7
|
+
data.tar.gz: 04ca056458d40e2ba7fabcdbcab415a087d54802fb3bd86748dc901c2cf0ecb44072fd1820a73e3dcaca097f165df3e70bab747b38340cd738876af5f0ea7645
|
data/README.md
CHANGED
|
@@ -41,6 +41,7 @@ custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
|
41
41
|
config.with(
|
|
42
42
|
output_tensor: "last_hidden_state",
|
|
43
43
|
max_length: 256,
|
|
44
|
+
padding: "batch_longest",
|
|
44
45
|
optimization_level: 3
|
|
45
46
|
)
|
|
46
47
|
end
|
|
@@ -55,12 +56,22 @@ Config fields and defaults:
|
|
|
55
56
|
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
56
57
|
- `output_tensor`: `nil` (auto-select output tensor)
|
|
57
58
|
- `max_length`: `nil` (uses tokenizer/model defaults)
|
|
59
|
+
- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
|
|
60
|
+
- `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
|
|
58
61
|
|
|
59
62
|
Notes:
|
|
60
63
|
|
|
61
64
|
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
62
65
|
- Model instances are cached by full config key; different config values create different cached instances.
|
|
63
66
|
|
|
67
|
+
Low-level embedder setup (without model cache):
|
|
68
|
+
|
|
69
|
+
```ruby
|
|
70
|
+
embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
71
|
+
config.with(threads: 0, execution_providers: "cpu")
|
|
72
|
+
end
|
|
73
|
+
```
|
|
74
|
+
|
|
64
75
|
## Reranker
|
|
65
76
|
|
|
66
77
|
Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
|
|
@@ -97,6 +108,8 @@ Reranker config fields and defaults:
|
|
|
97
108
|
- `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
|
|
98
109
|
- `output_tensor`: `nil`
|
|
99
110
|
- `max_length`: `nil`
|
|
111
|
+
- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
|
|
112
|
+
- `execution_providers`: `nil`
|
|
100
113
|
|
|
101
114
|
## Runtime + Result Examples
|
|
102
115
|
|
|
@@ -123,14 +136,30 @@ Input policy is text-only. Graphs requiring unsupported multimodal inputs (such
|
|
|
123
136
|
|
|
124
137
|
## Execution Providers
|
|
125
138
|
|
|
126
|
-
Default
|
|
139
|
+
Default behavior is CPU fallback via ONNX Runtime's default provider (no explicit provider registration).
|
|
140
|
+
|
|
141
|
+
Configure providers with `GTE_EXECUTION_PROVIDERS` (comma-separated, case-insensitive).
|
|
142
|
+
Supported values:
|
|
143
|
+
|
|
144
|
+
- `cpu` or `none`: CPU fallback (skip explicit provider registration)
|
|
145
|
+
- `xnnpack`
|
|
146
|
+
- `coreml`
|
|
127
147
|
|
|
128
|
-
|
|
148
|
+
Examples:
|
|
129
149
|
|
|
130
150
|
```bash
|
|
151
|
+
export GTE_EXECUTION_PROVIDERS=cpu
|
|
131
152
|
export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
132
153
|
```
|
|
133
154
|
|
|
155
|
+
Ruby per-instance override (takes precedence over `GTE_EXECUTION_PROVIDERS`):
|
|
156
|
+
|
|
157
|
+
```ruby
|
|
158
|
+
model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
159
|
+
config.with(execution_providers: "cpu")
|
|
160
|
+
end
|
|
161
|
+
```
|
|
162
|
+
|
|
134
163
|
## Development
|
|
135
164
|
|
|
136
165
|
Run commands inside `nix develop` via Make targets:
|
|
@@ -154,8 +183,13 @@ nix develop -c bundle exec rake bench:matrix_sweep
|
|
|
154
183
|
nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
155
184
|
```
|
|
156
185
|
|
|
157
|
-
|
|
186
|
+
To run benchmark + append a `RUNS.md` entry + enforce goal checks:
|
|
158
187
|
|
|
159
188
|
```bash
|
|
160
189
|
make bench-record
|
|
161
190
|
```
|
|
191
|
+
|
|
192
|
+
`bench/runs_ledger.rb check` is goal-focused by default:
|
|
193
|
+
|
|
194
|
+
- Enforces goal metric (`response_time_p95` ratio threshold).
|
|
195
|
+
- Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
|
data/Rakefile
CHANGED
|
@@ -56,7 +56,7 @@ namespace :bench do
|
|
|
56
56
|
)
|
|
57
57
|
end
|
|
58
58
|
|
|
59
|
-
desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal
|
|
59
|
+
desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal checks'
|
|
60
60
|
task :record_run do
|
|
61
61
|
run_in_nix(
|
|
62
62
|
'bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
|
|
@@ -74,7 +74,7 @@ namespace :bench do
|
|
|
74
74
|
)
|
|
75
75
|
end
|
|
76
76
|
|
|
77
|
-
desc 'Validate current Puma benchmark output against 2x goal
|
|
77
|
+
desc 'Validate current Puma benchmark output against 2x goal only'
|
|
78
78
|
task :check_goal do
|
|
79
79
|
run_in_nix(
|
|
80
80
|
'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'check',
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.
|
|
1
|
+
0.0.7
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
-
use crate::model_config::{ExtractorMode, ModelConfig};
|
|
2
|
+
use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
|
|
3
3
|
use crate::model_profile::{
|
|
4
|
-
has_input, infer_extraction_mode,
|
|
5
|
-
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
4
|
+
has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
|
|
5
|
+
resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
6
|
};
|
|
7
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
8
8
|
use crate::session::{build_session, run_session};
|
|
9
|
-
use crate::tokenizer::{Tokenized, Tokenizer};
|
|
9
|
+
use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
|
|
10
10
|
use ndarray::Array2;
|
|
11
11
|
use ort::session::Session;
|
|
12
12
|
use std::path::Path;
|
|
@@ -23,7 +23,13 @@ impl Embedder {
|
|
|
23
23
|
P1: AsRef<Path>,
|
|
24
24
|
P2: AsRef<Path>,
|
|
25
25
|
{
|
|
26
|
-
let tokenizer = Tokenizer::new(
|
|
26
|
+
let tokenizer = Tokenizer::new(
|
|
27
|
+
tokenizer_path,
|
|
28
|
+
config.max_length,
|
|
29
|
+
config.with_type_ids,
|
|
30
|
+
config.padding_mode,
|
|
31
|
+
None,
|
|
32
|
+
)?;
|
|
27
33
|
let session = build_session(model_path, &config)?;
|
|
28
34
|
Ok(Self {
|
|
29
35
|
tokenizer,
|
|
@@ -36,9 +42,7 @@ impl Embedder {
|
|
|
36
42
|
dir: P,
|
|
37
43
|
num_threads: usize,
|
|
38
44
|
optimization_level: u8,
|
|
39
|
-
|
|
40
|
-
output_tensor_override: Option<&str>,
|
|
41
|
-
max_length_override: Option<usize>,
|
|
45
|
+
overrides: ModelLoadOverrides<'_>,
|
|
42
46
|
) -> Result<Self> {
|
|
43
47
|
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
|
|
44
48
|
"pooler_output",
|
|
@@ -49,30 +53,35 @@ impl Embedder {
|
|
|
49
53
|
|
|
50
54
|
let dir = dir.as_ref();
|
|
51
55
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
52
|
-
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
56
|
+
let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
53
57
|
Some(name) => resolve_named_model(dir, name)?,
|
|
54
58
|
None => resolve_default_text_model(dir)?,
|
|
55
59
|
};
|
|
56
60
|
|
|
57
|
-
let
|
|
61
|
+
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
62
|
+
let max_length = if let Some(override_value) = overrides.max_length {
|
|
58
63
|
if override_value == 0 {
|
|
59
64
|
return Err(GteError::Inference(
|
|
60
65
|
"max_length override must be greater than 0".to_string(),
|
|
61
66
|
));
|
|
62
67
|
}
|
|
63
|
-
override_value
|
|
68
|
+
override_value.min(tokenizer_profile.safe_max_length)
|
|
64
69
|
} else {
|
|
65
|
-
|
|
70
|
+
tokenizer_profile.default_max_length
|
|
66
71
|
};
|
|
72
|
+
let padding_mode =
|
|
73
|
+
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
67
74
|
|
|
68
75
|
let session_config = ModelConfig {
|
|
69
76
|
max_length,
|
|
77
|
+
padding_mode,
|
|
70
78
|
output_tensor: String::new(),
|
|
71
79
|
mode: ExtractorMode::Raw,
|
|
72
80
|
with_type_ids: false,
|
|
73
81
|
with_attention_mask: true,
|
|
74
82
|
num_threads,
|
|
75
83
|
optimization_level,
|
|
84
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
76
85
|
};
|
|
77
86
|
let session = build_session(&model_path, &session_config)?;
|
|
78
87
|
|
|
@@ -80,7 +89,7 @@ impl Embedder {
|
|
|
80
89
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
81
90
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
82
91
|
let output_tensor =
|
|
83
|
-
select_output_tensor(&session,
|
|
92
|
+
select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
|
|
84
93
|
let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
|
|
85
94
|
if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
|
|
86
95
|
return Err(GteError::Inference(
|
|
@@ -90,15 +99,23 @@ impl Embedder {
|
|
|
90
99
|
|
|
91
100
|
let config = ModelConfig {
|
|
92
101
|
max_length,
|
|
102
|
+
padding_mode,
|
|
93
103
|
output_tensor,
|
|
94
104
|
mode,
|
|
95
105
|
with_type_ids,
|
|
96
106
|
with_attention_mask,
|
|
97
107
|
num_threads,
|
|
98
108
|
optimization_level,
|
|
109
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
99
110
|
};
|
|
100
111
|
|
|
101
|
-
let tokenizer = Tokenizer::new(
|
|
112
|
+
let tokenizer = Tokenizer::new(
|
|
113
|
+
&tokenizer_path,
|
|
114
|
+
config.max_length,
|
|
115
|
+
config.with_type_ids,
|
|
116
|
+
config.padding_mode,
|
|
117
|
+
tokenizer_profile.fixed_padding_length,
|
|
118
|
+
)?;
|
|
102
119
|
|
|
103
120
|
Ok(Self {
|
|
104
121
|
tokenizer,
|
data/ext/gte/src/model_config.rs
CHANGED
|
@@ -5,13 +5,32 @@ pub enum ExtractorMode {
|
|
|
5
5
|
Raw,
|
|
6
6
|
}
|
|
7
7
|
|
|
8
|
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
|
9
|
+
pub enum PaddingMode {
|
|
10
|
+
#[default]
|
|
11
|
+
Auto,
|
|
12
|
+
BatchLongest,
|
|
13
|
+
Fixed,
|
|
14
|
+
}
|
|
15
|
+
|
|
8
16
|
#[derive(Debug, Clone)]
|
|
9
17
|
pub struct ModelConfig {
|
|
10
18
|
pub max_length: usize,
|
|
19
|
+
pub padding_mode: PaddingMode,
|
|
11
20
|
pub output_tensor: String,
|
|
12
21
|
pub mode: ExtractorMode,
|
|
13
22
|
pub with_type_ids: bool,
|
|
14
23
|
pub with_attention_mask: bool,
|
|
15
24
|
pub num_threads: usize,
|
|
16
25
|
pub optimization_level: u8,
|
|
26
|
+
pub execution_providers: Option<String>,
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
#[derive(Debug, Clone, Copy, Default)]
|
|
30
|
+
pub struct ModelLoadOverrides<'a> {
|
|
31
|
+
pub model_name: Option<&'a str>,
|
|
32
|
+
pub output_tensor: Option<&'a str>,
|
|
33
|
+
pub max_length: Option<usize>,
|
|
34
|
+
pub padding: Option<&'a str>,
|
|
35
|
+
pub execution_providers: Option<&'a str>,
|
|
17
36
|
}
|
|
@@ -1,9 +1,19 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::ExtractorMode;
|
|
3
3
|
use ort::session::Session;
|
|
4
|
+
use serde_json::Value;
|
|
4
5
|
use std::path::{Path, PathBuf};
|
|
5
6
|
|
|
6
7
|
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
|
|
8
|
+
const DEFAULT_MAX_LENGTH: usize = 512;
|
|
9
|
+
const MAX_SUPPORTED_LENGTH: usize = 8192;
|
|
10
|
+
|
|
11
|
+
#[derive(Debug, Clone, Copy)]
|
|
12
|
+
pub struct TokenizerProfile {
|
|
13
|
+
pub default_max_length: usize,
|
|
14
|
+
pub safe_max_length: usize,
|
|
15
|
+
pub fixed_padding_length: Option<usize>,
|
|
16
|
+
}
|
|
7
17
|
|
|
8
18
|
pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
|
|
9
19
|
let tokenizer_path = dir.join("tokenizer.json");
|
|
@@ -48,19 +58,78 @@ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
|
|
|
48
58
|
)))
|
|
49
59
|
}
|
|
50
60
|
|
|
51
|
-
pub fn
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
Some(
|
|
62
|
-
|
|
63
|
-
|
|
61
|
+
pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
|
|
62
|
+
let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
|
|
63
|
+
let tokenizer_json = read_json(dir.join("tokenizer.json"));
|
|
64
|
+
|
|
65
|
+
let fixed_padding_length = tokenizer_json
|
|
66
|
+
.as_ref()
|
|
67
|
+
.and_then(parse_fixed_padding_length_from_tokenizer_json);
|
|
68
|
+
|
|
69
|
+
let mut candidates = Vec::new();
|
|
70
|
+
if let Some(config) = tokenizer_config.as_ref() {
|
|
71
|
+
if let Some(v) = config.get("max_length").and_then(parse_positive_usize) {
|
|
72
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
73
|
+
}
|
|
74
|
+
if let Some(v) = config.get("model_max_length").and_then(parse_positive_usize) {
|
|
75
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
if let Some(tokenizer) = tokenizer_json.as_ref() {
|
|
80
|
+
if let Some(v) = tokenizer
|
|
81
|
+
.get("truncation")
|
|
82
|
+
.and_then(|truncation| truncation.get("max_length"))
|
|
83
|
+
.and_then(parse_positive_usize)
|
|
84
|
+
{
|
|
85
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
if let Some(v) = fixed_padding_length {
|
|
90
|
+
candidates.push(v.min(MAX_SUPPORTED_LENGTH));
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
let default_max_length = candidates
|
|
94
|
+
.iter()
|
|
95
|
+
.copied()
|
|
96
|
+
.min()
|
|
97
|
+
.unwrap_or(DEFAULT_MAX_LENGTH)
|
|
98
|
+
.max(1);
|
|
99
|
+
let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
|
|
100
|
+
|
|
101
|
+
TokenizerProfile {
|
|
102
|
+
default_max_length,
|
|
103
|
+
safe_max_length,
|
|
104
|
+
fixed_padding_length,
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
fn read_json(path: PathBuf) -> Option<Value> {
|
|
109
|
+
let contents = std::fs::read_to_string(path).ok()?;
|
|
110
|
+
serde_json::from_str(&contents).ok()
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
fn parse_positive_usize(value: &Value) -> Option<usize> {
|
|
114
|
+
let raw = value
|
|
115
|
+
.as_u64()
|
|
116
|
+
.or_else(|| {
|
|
117
|
+
value
|
|
118
|
+
.as_f64()
|
|
119
|
+
.filter(|&v| v.is_finite() && v > 0.0)
|
|
120
|
+
.map(|v| v as u64)
|
|
121
|
+
})
|
|
122
|
+
.or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
|
|
123
|
+
let parsed = usize::try_from(raw).ok()?;
|
|
124
|
+
(parsed > 0).then_some(parsed)
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Option<usize> {
|
|
128
|
+
tokenizer_json
|
|
129
|
+
.get("padding")
|
|
130
|
+
.and_then(|padding| padding.get("strategy"))
|
|
131
|
+
.and_then(|strategy| strategy.get("Fixed"))
|
|
132
|
+
.and_then(parse_positive_usize)
|
|
64
133
|
}
|
|
65
134
|
|
|
66
135
|
pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
|
|
@@ -177,3 +246,32 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
|
|
|
177
246
|
))),
|
|
178
247
|
}
|
|
179
248
|
}
|
|
249
|
+
|
|
250
|
+
#[cfg(test)]
|
|
251
|
+
mod tests {
|
|
252
|
+
use super::{parse_fixed_padding_length_from_tokenizer_json, parse_positive_usize};
|
|
253
|
+
use serde_json::json;
|
|
254
|
+
|
|
255
|
+
#[test]
|
|
256
|
+
fn parse_positive_usize_handles_integer_float_and_string() {
|
|
257
|
+
assert_eq!(parse_positive_usize(&json!(64)), Some(64));
|
|
258
|
+
assert_eq!(parse_positive_usize(&json!(64.0)), Some(64));
|
|
259
|
+
assert_eq!(parse_positive_usize(&json!("64")), Some(64));
|
|
260
|
+
assert_eq!(parse_positive_usize(&json!(0)), None);
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
#[test]
|
|
264
|
+
fn parse_fixed_padding_length_reads_fixed_padding_strategy() {
|
|
265
|
+
let tokenizer_json = json!({
|
|
266
|
+
"padding": {
|
|
267
|
+
"strategy": {
|
|
268
|
+
"Fixed": 64
|
|
269
|
+
}
|
|
270
|
+
}
|
|
271
|
+
});
|
|
272
|
+
assert_eq!(
|
|
273
|
+
parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
|
|
274
|
+
Some(64)
|
|
275
|
+
);
|
|
276
|
+
}
|
|
277
|
+
}
|
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::{ModelLoadOverrides, PaddingMode};
|
|
2
3
|
use crate::model_profile::{
|
|
3
|
-
has_input,
|
|
4
|
-
select_output_tensor, validate_supported_text_inputs,
|
|
4
|
+
has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
|
|
5
|
+
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
5
6
|
};
|
|
6
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
7
8
|
use crate::postprocess::sigmoid_scores;
|
|
8
9
|
use crate::session::build_session;
|
|
9
|
-
use crate::tokenizer::Tokenizer;
|
|
10
|
-
use ndarray::Array1;
|
|
10
|
+
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
11
11
|
use ort::session::Session;
|
|
12
12
|
use std::path::Path;
|
|
13
13
|
|
|
14
14
|
#[derive(Debug, Clone)]
|
|
15
15
|
struct RerankerConfig {
|
|
16
16
|
max_length: usize,
|
|
17
|
+
padding_mode: PaddingMode,
|
|
17
18
|
output_tensor: String,
|
|
18
19
|
with_type_ids: bool,
|
|
19
20
|
with_attention_mask: bool,
|
|
@@ -30,52 +31,62 @@ impl Reranker {
|
|
|
30
31
|
dir: P,
|
|
31
32
|
num_threads: usize,
|
|
32
33
|
optimization_level: u8,
|
|
33
|
-
|
|
34
|
-
output_tensor_override: Option<&str>,
|
|
35
|
-
max_length_override: Option<usize>,
|
|
34
|
+
overrides: ModelLoadOverrides<'_>,
|
|
36
35
|
) -> Result<Self> {
|
|
37
36
|
let dir = dir.as_ref();
|
|
38
37
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
39
|
-
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
38
|
+
let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
40
39
|
Some(name) => resolve_named_model(dir, name)?,
|
|
41
40
|
None => resolve_default_text_model(dir)?,
|
|
42
41
|
};
|
|
43
42
|
|
|
44
|
-
let
|
|
43
|
+
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
44
|
+
let max_length = if let Some(override_value) = overrides.max_length {
|
|
45
45
|
if override_value == 0 {
|
|
46
46
|
return Err(GteError::Inference(
|
|
47
47
|
"max_length override must be greater than 0".to_string(),
|
|
48
48
|
));
|
|
49
49
|
}
|
|
50
|
-
override_value
|
|
50
|
+
override_value.min(tokenizer_profile.safe_max_length)
|
|
51
51
|
} else {
|
|
52
|
-
|
|
52
|
+
tokenizer_profile.default_max_length
|
|
53
53
|
};
|
|
54
|
+
let padding_mode =
|
|
55
|
+
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
54
56
|
|
|
55
57
|
let probe_config = crate::model_config::ModelConfig {
|
|
56
58
|
max_length,
|
|
59
|
+
padding_mode,
|
|
57
60
|
output_tensor: String::new(),
|
|
58
61
|
mode: crate::model_config::ExtractorMode::Raw,
|
|
59
62
|
with_type_ids: false,
|
|
60
63
|
with_attention_mask: true,
|
|
61
64
|
num_threads,
|
|
62
65
|
optimization_level,
|
|
66
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
63
67
|
};
|
|
64
68
|
let session = build_session(&model_path, &probe_config)?;
|
|
65
69
|
|
|
66
70
|
validate_supported_text_inputs(&session, "text reranking")?;
|
|
67
71
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
68
72
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
69
|
-
let output_tensor = select_output_tensor(&session,
|
|
73
|
+
let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
|
|
70
74
|
|
|
71
75
|
let config = RerankerConfig {
|
|
72
76
|
max_length,
|
|
77
|
+
padding_mode,
|
|
73
78
|
output_tensor,
|
|
74
79
|
with_type_ids,
|
|
75
80
|
with_attention_mask,
|
|
76
81
|
};
|
|
77
82
|
|
|
78
|
-
let tokenizer = Tokenizer::new(
|
|
83
|
+
let tokenizer = Tokenizer::new(
|
|
84
|
+
&tokenizer_path,
|
|
85
|
+
config.max_length,
|
|
86
|
+
config.with_type_ids,
|
|
87
|
+
config.padding_mode,
|
|
88
|
+
tokenizer_profile.fixed_padding_length,
|
|
89
|
+
)?;
|
|
79
90
|
|
|
80
91
|
Ok(Self {
|
|
81
92
|
tokenizer,
|
|
@@ -84,14 +95,27 @@ impl Reranker {
|
|
|
84
95
|
})
|
|
85
96
|
}
|
|
86
97
|
|
|
87
|
-
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<
|
|
98
|
+
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
88
99
|
let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
|
|
89
|
-
|
|
100
|
+
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
pub fn score(&self, query: &str, candidates: &[String], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
104
|
+
let tokenized = self.tokenizer.tokenize_query_candidates(query, candidates)?;
|
|
105
|
+
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
fn score_tokenized(
|
|
109
|
+
&self,
|
|
110
|
+
tokenized: &crate::tokenizer::Tokenized,
|
|
111
|
+
apply_sigmoid: bool,
|
|
112
|
+
) -> Result<Vec<f32>> {
|
|
113
|
+
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
90
114
|
let outputs = self.session.run(input_tensors.inputs)?;
|
|
91
115
|
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
92
116
|
|
|
93
117
|
let mut scores = match array.ndim() {
|
|
94
|
-
1 => array.into_dimensionality::<ndarray::Ix1>()?.
|
|
118
|
+
1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
|
|
95
119
|
2 => {
|
|
96
120
|
let shape = array.shape();
|
|
97
121
|
if shape[1] == 0 {
|
|
@@ -100,7 +124,7 @@ impl Reranker {
|
|
|
100
124
|
self.config.output_tensor, shape
|
|
101
125
|
)));
|
|
102
126
|
}
|
|
103
|
-
array.slice(ndarray::s![.., 0]).
|
|
127
|
+
array.slice(ndarray::s![.., 0]).to_vec()
|
|
104
128
|
}
|
|
105
129
|
n => {
|
|
106
130
|
return Err(GteError::Inference(format!(
|
|
@@ -111,10 +135,9 @@ impl Reranker {
|
|
|
111
135
|
};
|
|
112
136
|
|
|
113
137
|
if apply_sigmoid {
|
|
114
|
-
sigmoid_scores(scores.
|
|
138
|
+
sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
|
|
115
139
|
}
|
|
116
140
|
|
|
117
141
|
Ok(scores)
|
|
118
142
|
}
|
|
119
|
-
|
|
120
143
|
}
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
use crate::embedder::{normalize_l2, Embedder};
|
|
4
4
|
use crate::error::GteError;
|
|
5
|
+
use crate::model_config::ModelLoadOverrides;
|
|
5
6
|
use crate::reranker::Reranker;
|
|
6
7
|
use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
|
|
7
8
|
use std::os::raw::c_void;
|
|
@@ -38,7 +39,8 @@ unsafe impl Send for InferArgs {}
|
|
|
38
39
|
|
|
39
40
|
struct ScoreArgs {
|
|
40
41
|
reranker: *const Reranker,
|
|
41
|
-
|
|
42
|
+
query: *const String,
|
|
43
|
+
candidates: *const Vec<String>,
|
|
42
44
|
apply_sigmoid: bool,
|
|
43
45
|
result: Option<Result<Vec<f32>, GteError>>,
|
|
44
46
|
}
|
|
@@ -85,13 +87,15 @@ fn infer_without_gvl(
|
|
|
85
87
|
|
|
86
88
|
fn score_without_gvl(
|
|
87
89
|
reranker: &Arc<Reranker>,
|
|
88
|
-
|
|
90
|
+
query: String,
|
|
91
|
+
candidates: Vec<String>,
|
|
89
92
|
apply_sigmoid: bool,
|
|
90
93
|
) -> Result<Vec<f32>, Error> {
|
|
91
94
|
let scores = unsafe {
|
|
92
95
|
let mut args = ScoreArgs {
|
|
93
96
|
reranker: Arc::as_ptr(reranker),
|
|
94
|
-
|
|
97
|
+
query: &query as *const String,
|
|
98
|
+
candidates: &candidates as *const Vec<String>,
|
|
95
99
|
apply_sigmoid,
|
|
96
100
|
result: None,
|
|
97
101
|
};
|
|
@@ -135,8 +139,7 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
|
135
139
|
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
136
140
|
let args = &mut *(ptr as *mut ScoreArgs);
|
|
137
141
|
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
138
|
-
|
|
139
|
-
Ok(scores.to_vec())
|
|
142
|
+
(*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
|
|
140
143
|
}));
|
|
141
144
|
args.result = Some(match run_result {
|
|
142
145
|
Ok(result) => result,
|
|
@@ -171,6 +174,8 @@ impl RbEmbedder {
|
|
|
171
174
|
normalize: bool,
|
|
172
175
|
output_tensor: String,
|
|
173
176
|
max_length: usize,
|
|
177
|
+
padding: String,
|
|
178
|
+
execution_providers: String,
|
|
174
179
|
) -> Result<Self, Error> {
|
|
175
180
|
let name = if model_name.is_empty() {
|
|
176
181
|
None
|
|
@@ -187,13 +192,28 @@ impl RbEmbedder {
|
|
|
187
192
|
} else {
|
|
188
193
|
Some(max_length)
|
|
189
194
|
};
|
|
195
|
+
let execution_providers_override = if execution_providers.is_empty() {
|
|
196
|
+
None
|
|
197
|
+
} else {
|
|
198
|
+
Some(execution_providers.as_str())
|
|
199
|
+
};
|
|
200
|
+
let padding_override = if padding.is_empty() {
|
|
201
|
+
None
|
|
202
|
+
} else {
|
|
203
|
+
Some(padding.as_str())
|
|
204
|
+
};
|
|
205
|
+
let overrides = ModelLoadOverrides {
|
|
206
|
+
model_name: name,
|
|
207
|
+
output_tensor: output_override,
|
|
208
|
+
max_length: max_length_override,
|
|
209
|
+
padding: padding_override,
|
|
210
|
+
execution_providers: execution_providers_override,
|
|
211
|
+
};
|
|
190
212
|
let embedder = Embedder::from_dir(
|
|
191
213
|
&dir_path,
|
|
192
214
|
num_threads,
|
|
193
215
|
optimization_level,
|
|
194
|
-
|
|
195
|
-
output_override,
|
|
196
|
-
max_length_override,
|
|
216
|
+
overrides,
|
|
197
217
|
)
|
|
198
218
|
.map_err(magnus::Error::from)?;
|
|
199
219
|
Ok(RbEmbedder {
|
|
@@ -224,6 +244,8 @@ impl RbReranker {
|
|
|
224
244
|
sigmoid: bool,
|
|
225
245
|
output_tensor: String,
|
|
226
246
|
max_length: usize,
|
|
247
|
+
padding: String,
|
|
248
|
+
execution_providers: String,
|
|
227
249
|
) -> Result<Self, Error> {
|
|
228
250
|
let name = if model_name.is_empty() {
|
|
229
251
|
None
|
|
@@ -240,13 +262,28 @@ impl RbReranker {
|
|
|
240
262
|
} else {
|
|
241
263
|
Some(max_length)
|
|
242
264
|
};
|
|
265
|
+
let execution_providers_override = if execution_providers.is_empty() {
|
|
266
|
+
None
|
|
267
|
+
} else {
|
|
268
|
+
Some(execution_providers.as_str())
|
|
269
|
+
};
|
|
270
|
+
let padding_override = if padding.is_empty() {
|
|
271
|
+
None
|
|
272
|
+
} else {
|
|
273
|
+
Some(padding.as_str())
|
|
274
|
+
};
|
|
275
|
+
let overrides = ModelLoadOverrides {
|
|
276
|
+
model_name: name,
|
|
277
|
+
output_tensor: output_override,
|
|
278
|
+
max_length: max_length_override,
|
|
279
|
+
padding: padding_override,
|
|
280
|
+
execution_providers: execution_providers_override,
|
|
281
|
+
};
|
|
243
282
|
let reranker = Reranker::from_dir(
|
|
244
283
|
&dir_path,
|
|
245
284
|
num_threads,
|
|
246
285
|
optimization_level,
|
|
247
|
-
|
|
248
|
-
output_override,
|
|
249
|
-
max_length_override,
|
|
286
|
+
overrides,
|
|
250
287
|
)
|
|
251
288
|
.map_err(magnus::Error::from)?;
|
|
252
289
|
Ok(RbReranker {
|
|
@@ -262,11 +299,7 @@ impl RbReranker {
|
|
|
262
299
|
candidates: RArray,
|
|
263
300
|
) -> Result<RArray, Error> {
|
|
264
301
|
let candidates: Vec<String> = candidates.to_vec()?;
|
|
265
|
-
let
|
|
266
|
-
.into_iter()
|
|
267
|
-
.map(|candidate| (query.clone(), candidate))
|
|
268
|
-
.collect();
|
|
269
|
-
let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
|
|
302
|
+
let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
|
|
270
303
|
|
|
271
304
|
let out = ruby.ary_new_capa(scores.len());
|
|
272
305
|
for score in scores {
|
|
@@ -362,12 +395,12 @@ impl RbTensor {
|
|
|
362
395
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
363
396
|
let module = ruby.define_module("GTE")?;
|
|
364
397
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
365
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
398
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
|
|
366
399
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
367
400
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
368
401
|
|
|
369
402
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
370
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new,
|
|
403
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
|
|
371
404
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
372
405
|
|
|
373
406
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|
data/ext/gte/src/session.rs
CHANGED
|
@@ -22,7 +22,7 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
22
22
|
.with_optimization_level(opt_level)?
|
|
23
23
|
.with_memory_pattern(true)?;
|
|
24
24
|
|
|
25
|
-
let providers = preferred_execution_providers();
|
|
25
|
+
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
26
26
|
if !providers.is_empty() {
|
|
27
27
|
builder = builder.with_execution_providers(providers)?;
|
|
28
28
|
}
|
|
@@ -34,19 +34,40 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
34
34
|
Ok(builder.commit_from_file(model_path)?)
|
|
35
35
|
}
|
|
36
36
|
|
|
37
|
-
fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
|
|
38
|
-
let order =
|
|
39
|
-
.unwrap_or_else(|_| "xnnpack".to_string())
|
|
40
|
-
.to_ascii_lowercase();
|
|
37
|
+
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
38
|
+
let order = resolve_provider_order(order_override);
|
|
41
39
|
|
|
42
40
|
let mut providers = Vec::new();
|
|
43
|
-
for provider in order.
|
|
41
|
+
for provider in parse_provider_registrations(order.as_str()) {
|
|
44
42
|
match provider {
|
|
45
43
|
"xnnpack" => {
|
|
46
44
|
providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
|
|
47
45
|
}
|
|
48
46
|
"coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
49
|
-
|
|
47
|
+
_ => {}
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
providers
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
fn resolve_provider_order(order_override: Option<&str>) -> String {
|
|
54
|
+
let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
|
|
55
|
+
resolve_provider_order_with_env(order_override, env_order.as_deref())
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
59
|
+
order_override
|
|
60
|
+
.or(env_order)
|
|
61
|
+
.unwrap_or("cpu")
|
|
62
|
+
.to_ascii_lowercase()
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
66
|
+
let mut providers = Vec::new();
|
|
67
|
+
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
68
|
+
match provider {
|
|
69
|
+
"xnnpack" | "coreml" => providers.push(provider),
|
|
70
|
+
"none" | "cpu" => {}
|
|
50
71
|
_ => {}
|
|
51
72
|
}
|
|
52
73
|
}
|
|
@@ -86,3 +107,42 @@ pub fn run_session(
|
|
|
86
107
|
ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
|
|
87
108
|
}
|
|
88
109
|
}
|
|
110
|
+
|
|
111
|
+
#[cfg(test)]
mod tests {
    use super::{parse_provider_registrations, resolve_provider_order_with_env};

    #[test]
    fn parse_provider_registrations_keeps_supported_order() {
        // Supported providers come back in the order they were listed.
        assert_eq!(
            parse_provider_registrations("xnnpack,coreml"),
            vec!["xnnpack", "coreml"]
        );
    }

    #[test]
    fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
        // "cpu" and "none" mean "register nothing"; ORT falls back to CPU.
        for order in ["cpu", "none", "none,cpu"] {
            assert!(parse_provider_registrations(order).is_empty());
        }
    }

    #[test]
    fn parse_provider_registrations_ignores_unknowns_and_empties() {
        // Blank entries, typos, and unknown names are silently dropped.
        assert_eq!(
            parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,"),
            vec!["xnnpack", "coreml"]
        );
    }

    #[test]
    fn resolve_provider_order_prefers_override() {
        let resolved = resolve_provider_order_with_env(Some("xnnpack"), Some("coreml"));
        assert_eq!(resolved, "xnnpack");
        // Overrides are normalized to lowercase.
        assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
    }

    #[test]
    fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
        assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
        assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
    }
}
|
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::PaddingMode;
|
|
2
3
|
use std::path::Path;
|
|
3
4
|
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
|
|
4
5
|
|
|
@@ -20,6 +21,8 @@ impl Tokenizer {
|
|
|
20
21
|
tokenizer_path: P,
|
|
21
22
|
max_length: usize,
|
|
22
23
|
with_type_ids: bool,
|
|
24
|
+
padding_mode: PaddingMode,
|
|
25
|
+
fixed_padding_length: Option<usize>,
|
|
23
26
|
) -> Result<Self> {
|
|
24
27
|
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
|
|
25
28
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
@@ -33,7 +36,7 @@ impl Tokenizer {
|
|
|
33
36
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
34
37
|
|
|
35
38
|
let padding = PaddingParams {
|
|
36
|
-
strategy:
|
|
39
|
+
strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
|
|
37
40
|
..Default::default()
|
|
38
41
|
};
|
|
39
42
|
tokenizer.with_padding(Some(padding));
|
|
@@ -73,6 +76,56 @@ impl Tokenizer {
|
|
|
73
76
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
74
77
|
build_tokenized(&encodings, self.with_type_ids)
|
|
75
78
|
}
|
|
79
|
+
|
|
80
|
+
pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
|
|
81
|
+
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = candidates
|
|
82
|
+
.iter()
|
|
83
|
+
.map(|candidate| (query, candidate.as_str()).into())
|
|
84
|
+
.collect();
|
|
85
|
+
let encodings = self
|
|
86
|
+
.tokenizer
|
|
87
|
+
.encode_batch_fast(encode_inputs, true)
|
|
88
|
+
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
89
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<PaddingMode>> {
|
|
94
|
+
let Some(raw) = value.map(str::trim).filter(|v| !v.is_empty()) else {
|
|
95
|
+
return Ok(None);
|
|
96
|
+
};
|
|
97
|
+
|
|
98
|
+
let normalized = raw.to_ascii_lowercase().replace('-', "_");
|
|
99
|
+
let parsed = match normalized.as_str() {
|
|
100
|
+
"auto" => PaddingMode::Auto,
|
|
101
|
+
"batch_longest" | "batchlongest" => PaddingMode::BatchLongest,
|
|
102
|
+
"fixed" => PaddingMode::Fixed,
|
|
103
|
+
_ => {
|
|
104
|
+
return Err(GteError::Inference(format!(
|
|
105
|
+
"invalid padding mode '{}'; expected one of: auto, batch_longest, fixed",
|
|
106
|
+
raw
|
|
107
|
+
)))
|
|
108
|
+
}
|
|
109
|
+
};
|
|
110
|
+
Ok(Some(parsed))
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
fn resolve_padding_strategy(
|
|
114
|
+
padding_mode: PaddingMode,
|
|
115
|
+
max_length: usize,
|
|
116
|
+
fixed_padding_length: Option<usize>,
|
|
117
|
+
) -> PaddingStrategy {
|
|
118
|
+
match padding_mode {
|
|
119
|
+
PaddingMode::BatchLongest => PaddingStrategy::BatchLongest,
|
|
120
|
+
PaddingMode::Fixed => PaddingStrategy::Fixed(max_length),
|
|
121
|
+
PaddingMode::Auto => {
|
|
122
|
+
if fixed_padding_length.is_some() {
|
|
123
|
+
PaddingStrategy::Fixed(max_length)
|
|
124
|
+
} else {
|
|
125
|
+
PaddingStrategy::BatchLongest
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
}
|
|
76
129
|
}
|
|
77
130
|
|
|
78
131
|
fn build_tokenized_single(
|
|
@@ -121,21 +174,17 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
121
174
|
let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
|
|
122
175
|
|
|
123
176
|
for encoding in encodings {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
);
|
|
177
|
+
for &value in encoding.get_ids() {
|
|
178
|
+
input_ids.push(i64::from(value));
|
|
179
|
+
}
|
|
180
|
+
for &value in encoding.get_attention_mask() {
|
|
181
|
+
attn_masks.push(i64::from(value));
|
|
182
|
+
}
|
|
131
183
|
|
|
132
184
|
if let Some(type_ids) = type_ids.as_mut() {
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
.iter()
|
|
137
|
-
.map(|&value| i64::from(value)),
|
|
138
|
-
);
|
|
185
|
+
for &value in encoding.get_type_ids() {
|
|
186
|
+
type_ids.push(i64::from(value));
|
|
187
|
+
}
|
|
139
188
|
}
|
|
140
189
|
}
|
|
141
190
|
|
|
@@ -147,3 +196,39 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
147
196
|
type_ids,
|
|
148
197
|
})
|
|
149
198
|
}
|
|
199
|
+
|
|
200
|
+
#[cfg(test)]
mod tests {
    use super::{parse_padding_mode_override, resolve_padding_strategy};
    use crate::model_config::PaddingMode;
    use tokenizers::PaddingStrategy;

    #[test]
    fn parse_padding_mode_override_accepts_expected_values() {
        // Dashes are accepted in place of underscores ("batch-longest").
        let cases = [
            ("auto", PaddingMode::Auto),
            ("batch-longest", PaddingMode::BatchLongest),
            ("fixed", PaddingMode::Fixed),
        ];
        for (input, expected) in cases {
            assert_eq!(
                parse_padding_mode_override(Some(input)).unwrap(),
                Some(expected)
            );
        }
    }

    #[test]
    fn parse_padding_mode_override_rejects_invalid_values() {
        assert!(parse_padding_mode_override(Some("unknown")).is_err());
    }

    #[test]
    fn resolve_padding_strategy_uses_fixed_for_auto_when_model_has_fixed_padding() {
        match resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)) {
            PaddingStrategy::Fixed(64) => {}
            other => panic!("expected Fixed(64), got {:?}", other),
        }
    }
}
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
use gte::embedder::Embedder;
|
|
2
|
+
use gte::model_config::ModelLoadOverrides;
|
|
2
3
|
|
|
3
4
|
#[test]
|
|
4
5
|
#[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json and model.onnx"]
|
|
5
6
|
fn test_e5_single_embedding_shape() {
|
|
6
7
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
7
8
|
|
|
8
|
-
let embedder =
|
|
9
|
-
|
|
9
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
10
|
+
.expect("embedder should initialize");
|
|
10
11
|
let result = embedder
|
|
11
12
|
.embed(vec!["query: Hello world".to_string()])
|
|
12
13
|
.expect("embed should succeed");
|
|
@@ -20,8 +21,8 @@ fn test_e5_single_embedding_shape() {
|
|
|
20
21
|
fn test_clip_single_embedding_shape() {
|
|
21
22
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
|
|
22
23
|
|
|
23
|
-
let embedder =
|
|
24
|
-
|
|
24
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
25
|
+
.expect("embedder should initialize");
|
|
25
26
|
let result = embedder
|
|
26
27
|
.embed(vec!["a photo of a cat".to_string()])
|
|
27
28
|
.expect("embed should succeed");
|
|
@@ -35,8 +36,8 @@ fn test_clip_single_embedding_shape() {
|
|
|
35
36
|
fn test_e5_batch_embedding_shape() {
|
|
36
37
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
37
38
|
|
|
38
|
-
let embedder =
|
|
39
|
-
|
|
39
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
40
|
+
.expect("embedder should initialize");
|
|
40
41
|
let texts = vec![
|
|
41
42
|
"query: first sentence".to_string(),
|
|
42
43
|
"query: second sentence".to_string(),
|
|
@@ -54,8 +55,8 @@ fn test_e5_batch_embedding_shape() {
|
|
|
54
55
|
fn test_e5_long_input_truncation_no_error() {
|
|
55
56
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
56
57
|
|
|
57
|
-
let embedder =
|
|
58
|
-
|
|
58
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
59
|
+
.expect("embedder should initialize");
|
|
59
60
|
let very_long_text = "word ".repeat(1000);
|
|
60
61
|
let result = embedder
|
|
61
62
|
.embed(vec![very_long_text])
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
use gte::model_config::PaddingMode;
|
|
1
2
|
use gte::tokenizer::Tokenizer;
|
|
2
3
|
|
|
3
4
|
#[test]
|
|
@@ -8,7 +9,8 @@ fn test_e5_tokenizer_output_shape() {
|
|
|
8
9
|
"/tests/fixtures/e5/tokenizer.json"
|
|
9
10
|
);
|
|
10
11
|
|
|
11
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 512, true
|
|
12
|
+
let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
|
|
13
|
+
.expect("tokenizer should load");
|
|
12
14
|
let texts = vec![
|
|
13
15
|
"Hello, world!".to_string(),
|
|
14
16
|
"A second, longer sentence to test padding behavior.".to_string(),
|
|
@@ -33,7 +35,8 @@ fn test_e5_truncation_at_max_length() {
|
|
|
33
35
|
"/tests/fixtures/e5/tokenizer.json"
|
|
34
36
|
);
|
|
35
37
|
|
|
36
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 16, false
|
|
38
|
+
let tokenizer = Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None)
|
|
39
|
+
.expect("tokenizer should load");
|
|
37
40
|
let long_text = "word ".repeat(200);
|
|
38
41
|
let tokenized = tokenizer
|
|
39
42
|
.tokenize(&[long_text])
|
data/lib/gte/config.rb
CHANGED
|
@@ -4,12 +4,12 @@ module GTE
|
|
|
4
4
|
module Config
|
|
5
5
|
Text = Data.define(
|
|
6
6
|
:model_dir, :threads, :optimization_level,
|
|
7
|
-
:model_name, :normalize, :output_tensor, :max_length
|
|
7
|
+
:model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
|
|
8
8
|
)
|
|
9
9
|
|
|
10
10
|
Reranker = Data.define(
|
|
11
11
|
:model_dir, :threads, :optimization_level,
|
|
12
|
-
:model_name, :sigmoid, :output_tensor, :max_length
|
|
12
|
+
:model_name, :sigmoid, :output_tensor, :max_length, :padding, :execution_providers
|
|
13
13
|
)
|
|
14
14
|
end
|
|
15
15
|
end
|
data/lib/gte/embedder.rb
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# frozen_string_literal: true

module GTE
  # Ruby-side construction helpers for the native embedder, mirroring the
  # builder API exposed by GTE::Reranker.
  class Embedder
    class << self
      # Build an embedder for +model_dir+. When a block is given, the
      # default Config::Text is yielded so callers can customize it before
      # the native object is constructed.
      def config(model_dir)
        settings = default_config(model_dir)
        settings = yield(settings) if block_given?
        from_config(settings)
      end

      # Instantiate the native embedder from a GTE::Config::Text value,
      # coercing nil-able fields to the sentinel values the extension
      # expects ("" for strings, 0 for max_length).
      def from_config(config)
        native_args = [
          config.model_dir,
          config.threads,
          config.optimization_level,
          config.model_name.to_s,
          config.normalize,
          config.output_tensor.to_s,
          config.max_length || 0,
          config.padding.to_s,
          config.execution_providers.to_s
        ]
        new(*native_args)
      end

      private

      # Default settings; nil fields mean "let the native side decide".
      def default_config(model_dir)
        Config::Text.new(
          model_dir: File.expand_path(model_dir),
          threads: 3,
          optimization_level: 3,
          model_name: nil,
          normalize: true,
          output_tensor: nil,
          max_length: nil,
          padding: nil,
          execution_providers: nil
        )
      end
    end
  end
end
|
data/lib/gte/model.rb
CHANGED
|
@@ -8,15 +8,7 @@ module GTE
|
|
|
8
8
|
raise ArgumentError, 'config must be a GTE::Config::Text' unless config.is_a?(Config::Text)
|
|
9
9
|
|
|
10
10
|
@config = config
|
|
11
|
-
@embedder = GTE::Embedder.
|
|
12
|
-
config.model_dir,
|
|
13
|
-
config.threads,
|
|
14
|
-
config.optimization_level,
|
|
15
|
-
config.model_name.to_s,
|
|
16
|
-
config.normalize,
|
|
17
|
-
config.output_tensor.to_s,
|
|
18
|
-
config.max_length || 0
|
|
19
|
-
)
|
|
11
|
+
@embedder = GTE::Embedder.from_config(config)
|
|
20
12
|
end
|
|
21
13
|
|
|
22
14
|
def embed(texts)
|
data/lib/gte/reranker.rb
CHANGED
|
@@ -24,7 +24,9 @@ module GTE
|
|
|
24
24
|
model_name: nil,
|
|
25
25
|
sigmoid: false,
|
|
26
26
|
output_tensor: nil,
|
|
27
|
-
max_length: nil
|
|
27
|
+
max_length: nil,
|
|
28
|
+
padding: nil,
|
|
29
|
+
execution_providers: nil
|
|
28
30
|
)
|
|
29
31
|
end
|
|
30
32
|
|
|
@@ -36,7 +38,9 @@ module GTE
|
|
|
36
38
|
cfg.model_name.to_s,
|
|
37
39
|
cfg.sigmoid,
|
|
38
40
|
cfg.output_tensor.to_s,
|
|
39
|
-
cfg.max_length || 0
|
|
41
|
+
cfg.max_length || 0,
|
|
42
|
+
cfg.padding.to_s,
|
|
43
|
+
cfg.execution_providers.to_s
|
|
40
44
|
)
|
|
41
45
|
end
|
|
42
46
|
end
|
data/lib/gte.rb
CHANGED
|
@@ -9,6 +9,7 @@ rescue LoadError
|
|
|
9
9
|
end
|
|
10
10
|
|
|
11
11
|
require 'gte/config'
|
|
12
|
+
require 'gte/embedder'
|
|
12
13
|
require 'gte/model'
|
|
13
14
|
require 'gte/reranker'
|
|
14
15
|
|
|
@@ -25,7 +26,9 @@ module GTE
|
|
|
25
26
|
model_name: nil,
|
|
26
27
|
normalize: true,
|
|
27
28
|
output_tensor: nil,
|
|
28
|
-
max_length: nil
|
|
29
|
+
max_length: nil,
|
|
30
|
+
padding: nil,
|
|
31
|
+
execution_providers: nil
|
|
29
32
|
)
|
|
30
33
|
|
|
31
34
|
cfg = yield(cfg) if block_given?
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: gte
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.0.
|
|
4
|
+
version: 0.0.7
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- elcuervo
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: bin
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2026-04-
|
|
11
|
+
date: 2026-04-16 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: rake
|
|
@@ -114,6 +114,7 @@ files:
|
|
|
114
114
|
- ext/gte/tests/tokenizer_unit_test.rs
|
|
115
115
|
- lib/gte.rb
|
|
116
116
|
- lib/gte/config.rb
|
|
117
|
+
- lib/gte/embedder.rb
|
|
117
118
|
- lib/gte/model.rb
|
|
118
119
|
- lib/gte/reranker.rb
|
|
119
120
|
- lib/gte/version.rb
|