gte 0.0.4 → 0.0.6
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +150 -14
- data/Rakefile +2 -2
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +38 -253
- data/ext/gte/src/lib.rs +3 -0
- data/ext/gte/src/model_config.rs +1 -0
- data/ext/gte/src/model_profile.rs +179 -0
- data/ext/gte/src/pipeline.rs +60 -0
- data/ext/gte/src/postprocess.rs +6 -0
- data/ext/gte/src/reranker.rs +122 -0
- data/ext/gte/src/ruby_embedder.rs +179 -7
- data/ext/gte/src/session.rs +76 -46
- data/ext/gte/src/tokenizer.rs +21 -2
- data/ext/gte/tests/inference_integration_test.rs +8 -4
- data/ext/gte/tests/postprocess_unit_test.rs +17 -0
- data/ext/gte/tests/tokenizer_unit_test.rs +4 -1
- data/lib/gte/config.rb +15 -0
- data/lib/gte/embedder.rb +41 -0
- data/lib/gte/model.rb +27 -0
- data/lib/gte/reranker.rb +56 -0
- data/lib/gte/version.rb +5 -0
- data/lib/gte.rb +26 -35
- metadata +11 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: fc149108c647dc5b14154bfbdc4975b53670b9ed3cf7d80760cc2b415c935a48
|
|
4
|
+
data.tar.gz: 32a682a95d56c8fab8d0d64a7ada0c0347ae796b6aefe6191f9aca8fc96426c2
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: f5c69d954f51a51521b143b576942a9c0505ad60574c1727f963dd79e0b6c22cacc4e6d9af75394ae06f451521dbc788af51f1e79397a5cc66a41b4ce1b31933
|
|
7
|
+
data.tar.gz: 9e75fdbc9b5c8cfdd9d0e377a7e4a944057ec604e38ab23d960c4ed75ec6a72ce1dd27c2dd1bb2802721387babdabe0996e0c42be34d17d98253e0582b375de1
|
data/README.md
CHANGED
|
@@ -9,14 +9,115 @@ Inspired by https://github.com/fbilhaut/gte-rs
|
|
|
9
9
|
```ruby
|
|
10
10
|
require "gte"
|
|
11
11
|
|
|
12
|
-
model = GTE.
|
|
13
|
-
|
|
12
|
+
model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
13
|
+
|
|
14
|
+
# String input => GTE::Tensor (1 row)
|
|
15
|
+
tensor = model.embed("query: hello world")
|
|
16
|
+
vector = tensor.row(0)
|
|
17
|
+
|
|
18
|
+
# [] with string => Array<Float> (single vector)
|
|
19
|
+
single = model["query: nearest coffee shop"]
|
|
20
|
+
|
|
21
|
+
# [] with array => GTE::Tensor (batch)
|
|
22
|
+
batch = model[["query: hello", "query: world"]]
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
## Embedding Config (`GTE.config`)
|
|
26
|
+
|
|
27
|
+
`GTE.config(model_dir)` builds (and caches) a `GTE::Model`.
|
|
28
|
+
|
|
29
|
+
```ruby
|
|
30
|
+
default_model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
31
|
+
|
|
32
|
+
raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
33
|
+
config.with(normalize: false)
|
|
34
|
+
end
|
|
35
|
+
|
|
36
|
+
full_throttle = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
37
|
+
config.with(threads: 0)
|
|
38
|
+
end
|
|
39
|
+
|
|
40
|
+
custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
41
|
+
config.with(
|
|
42
|
+
output_tensor: "last_hidden_state",
|
|
43
|
+
max_length: 256,
|
|
44
|
+
optimization_level: 3
|
|
45
|
+
)
|
|
46
|
+
end
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Config fields and defaults:
|
|
50
|
+
|
|
51
|
+
- `model_dir`: absolute path to model directory
|
|
52
|
+
- `threads`: `3` (set `0` for ONNX Runtime full-throttle threadpool)
|
|
53
|
+
- `optimization_level`: `3`
|
|
54
|
+
- `model_name`: `nil`
|
|
55
|
+
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
56
|
+
- `output_tensor`: `nil` (auto-select output tensor)
|
|
57
|
+
- `max_length`: `nil` (uses tokenizer/model defaults)
|
|
58
|
+
- `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
|
|
59
|
+
|
|
60
|
+
Notes:
|
|
61
|
+
|
|
62
|
+
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
63
|
+
- Model instances are cached by full config key; different config values create different cached instances.
|
|
64
|
+
|
|
65
|
+
Low-level embedder setup (without model cache):
|
|
66
|
+
|
|
67
|
+
```ruby
|
|
68
|
+
embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
69
|
+
config.with(threads: 0, execution_providers: "cpu")
|
|
70
|
+
end
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
## Reranker
|
|
74
|
+
|
|
75
|
+
Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
|
|
76
|
+
|
|
77
|
+
```ruby
|
|
78
|
+
reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
|
|
79
|
+
config.with(sigmoid: true, threads: 0)
|
|
80
|
+
end
|
|
81
|
+
|
|
82
|
+
query = "how to train a neural network?"
|
|
83
|
+
candidates = [
|
|
84
|
+
"Backpropagation and gradient descent are core techniques.",
|
|
85
|
+
"This recipe uses flour and eggs."
|
|
86
|
+
]
|
|
87
|
+
|
|
88
|
+
# Raw scores aligned with input order
|
|
89
|
+
scores = reranker.score(query, candidates)
|
|
90
|
+
# => [0.93, 0.07]
|
|
91
|
+
|
|
92
|
+
# Ranked output sorted by score desc
|
|
93
|
+
ranked = reranker.rerank(query: query, candidates: candidates)
|
|
94
|
+
# => [
|
|
95
|
+
# { index: 0, score: 0.93, text: "Backpropagation and gradient descent are core techniques." },
|
|
96
|
+
# { index: 1, score: 0.07, text: "This recipe uses flour and eggs." }
|
|
97
|
+
# ]
|
|
14
98
|
```
|
|
15
99
|
|
|
16
|
-
|
|
100
|
+
Reranker config fields and defaults:
|
|
101
|
+
|
|
102
|
+
- `model_dir`: absolute path to model directory
|
|
103
|
+
- `threads`: `3`
|
|
104
|
+
- `optimization_level`: `3`
|
|
105
|
+
- `model_name`: `nil`
|
|
106
|
+
- `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
|
|
107
|
+
- `output_tensor`: `nil`
|
|
108
|
+
- `max_length`: `nil`
|
|
109
|
+
- `execution_providers`: `nil`
|
|
110
|
+
|
|
111
|
+
## Runtime + Result Examples
|
|
112
|
+
|
|
113
|
+
Process-local reuse (recommended for Puma/web servers):
|
|
17
114
|
|
|
18
115
|
```ruby
|
|
19
|
-
|
|
116
|
+
EMBEDDER = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
117
|
+
|
|
118
|
+
def embed_query(text)
|
|
119
|
+
EMBEDDER[text] # Array<Float>
|
|
120
|
+
end
|
|
20
121
|
```
|
|
21
122
|
|
|
22
123
|
## Model Directory
|
|
@@ -28,14 +129,44 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
|
|
|
28
129
|
3. `onnx/model.onnx`
|
|
29
130
|
4. `model.onnx`
|
|
30
131
|
|
|
132
|
+
Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
|
|
133
|
+
|
|
134
|
+
## Execution Providers
|
|
135
|
+
|
|
136
|
+
Default behavior is CPU fallback via ONNX Runtime's default provider (no explicit provider registration).
|
|
137
|
+
|
|
138
|
+
Configure providers with `GTE_EXECUTION_PROVIDERS` (comma-separated, case-insensitive).
|
|
139
|
+
Supported values:
|
|
140
|
+
|
|
141
|
+
- `cpu` or `none`: CPU fallback (skip explicit provider registration)
|
|
142
|
+
- `xnnpack`
|
|
143
|
+
- `coreml`
|
|
144
|
+
|
|
145
|
+
Examples:
|
|
146
|
+
|
|
147
|
+
```bash
|
|
148
|
+
export GTE_EXECUTION_PROVIDERS=cpu
|
|
149
|
+
export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
150
|
+
```
|
|
151
|
+
|
|
152
|
+
Ruby per-instance override (takes precedence over `GTE_EXECUTION_PROVIDERS`):
|
|
153
|
+
|
|
154
|
+
```ruby
|
|
155
|
+
model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
156
|
+
config.with(execution_providers: "cpu")
|
|
157
|
+
end
|
|
158
|
+
```
|
|
159
|
+
|
|
31
160
|
## Development
|
|
32
161
|
|
|
33
|
-
Run commands inside `nix develop
|
|
162
|
+
Run commands inside `nix develop` via Make targets:
|
|
34
163
|
|
|
35
164
|
```bash
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
165
|
+
make setup
|
|
166
|
+
make compile
|
|
167
|
+
make test
|
|
168
|
+
make lint
|
|
169
|
+
make ci
|
|
39
170
|
```
|
|
40
171
|
|
|
41
172
|
## Benchmark
|
|
@@ -43,14 +174,19 @@ bundle exec rspec
|
|
|
43
174
|
The repo includes two benchmark paths:
|
|
44
175
|
|
|
45
176
|
```bash
|
|
46
|
-
|
|
47
|
-
bundle exec rake bench:
|
|
48
|
-
bundle exec rake bench:matrix_sweep
|
|
49
|
-
bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
177
|
+
make bench
|
|
178
|
+
nix develop -c bundle exec rake bench:pure_compare
|
|
179
|
+
nix develop -c bundle exec rake bench:matrix_sweep
|
|
180
|
+
nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
50
181
|
```
|
|
51
182
|
|
|
52
|
-
|
|
183
|
+
To run benchmark + append a `RUNS.md` entry + enforce goal checks:
|
|
53
184
|
|
|
54
185
|
```bash
|
|
55
|
-
|
|
186
|
+
make bench-record
|
|
56
187
|
```
|
|
188
|
+
|
|
189
|
+
`bench/runs_ledger.rb check` is goal-focused by default:
|
|
190
|
+
|
|
191
|
+
- Enforces goal metric (`response_time_p95` ratio threshold).
|
|
192
|
+
- Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
|
data/Rakefile
CHANGED
|
@@ -56,7 +56,7 @@ namespace :bench do
|
|
|
56
56
|
)
|
|
57
57
|
end
|
|
58
58
|
|
|
59
|
-
desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal
|
|
59
|
+
desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal checks'
|
|
60
60
|
task :record_run do
|
|
61
61
|
run_in_nix(
|
|
62
62
|
'bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
|
|
@@ -74,7 +74,7 @@ namespace :bench do
|
|
|
74
74
|
)
|
|
75
75
|
end
|
|
76
76
|
|
|
77
|
-
desc 'Validate current Puma benchmark output against 2x goal
|
|
77
|
+
desc 'Validate current Puma benchmark output against 2x goal only'
|
|
78
78
|
task :check_goal do
|
|
79
79
|
run_in_nix(
|
|
80
80
|
'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'check',
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.
|
|
1
|
+
0.0.6
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -1,19 +1,15 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::{ExtractorMode, ModelConfig};
|
|
3
|
+
use crate::model_profile::{
|
|
4
|
+
has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
|
|
5
|
+
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
|
+
};
|
|
3
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
4
8
|
use crate::session::{build_session, run_session};
|
|
5
9
|
use crate::tokenizer::{Tokenized, Tokenizer};
|
|
6
10
|
use ndarray::Array2;
|
|
7
11
|
use ort::session::Session;
|
|
8
|
-
use std::path::
|
|
9
|
-
|
|
10
|
-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
11
|
-
pub enum ModelFamily {
|
|
12
|
-
E5Like,
|
|
13
|
-
SiglipLike,
|
|
14
|
-
ClipLike,
|
|
15
|
-
Other,
|
|
16
|
-
}
|
|
12
|
+
use std::path::Path;
|
|
17
13
|
|
|
18
14
|
pub struct Embedder {
|
|
19
15
|
tokenizer: Tokenizer,
|
|
@@ -41,39 +37,52 @@ impl Embedder {
|
|
|
41
37
|
num_threads: usize,
|
|
42
38
|
optimization_level: u8,
|
|
43
39
|
model_name: Option<&str>,
|
|
40
|
+
output_tensor_override: Option<&str>,
|
|
41
|
+
max_length_override: Option<usize>,
|
|
42
|
+
execution_providers_override: Option<&str>,
|
|
44
43
|
) -> Result<Self> {
|
|
44
|
+
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
|
|
45
|
+
"pooler_output",
|
|
46
|
+
"text_embeds",
|
|
47
|
+
"sentence_embedding",
|
|
48
|
+
"last_hidden_state",
|
|
49
|
+
];
|
|
50
|
+
|
|
45
51
|
let dir = dir.as_ref();
|
|
46
|
-
let tokenizer_path = dir
|
|
52
|
+
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
47
53
|
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
48
54
|
Some(name) => resolve_named_model(dir, name)?,
|
|
49
|
-
None =>
|
|
55
|
+
None => resolve_default_text_model(dir)?,
|
|
50
56
|
};
|
|
51
57
|
|
|
52
|
-
if
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
+
let max_length = if let Some(override_value) = max_length_override {
|
|
59
|
+
if override_value == 0 {
|
|
60
|
+
return Err(GteError::Inference(
|
|
61
|
+
"max_length override must be greater than 0".to_string(),
|
|
62
|
+
));
|
|
63
|
+
}
|
|
64
|
+
override_value
|
|
65
|
+
} else {
|
|
66
|
+
read_max_length(dir)
|
|
67
|
+
};
|
|
58
68
|
|
|
59
|
-
let
|
|
60
|
-
let probe_num_threads = if num_threads == 0 { 1 } else { num_threads };
|
|
61
|
-
let temp_config = ModelConfig {
|
|
69
|
+
let session_config = ModelConfig {
|
|
62
70
|
max_length,
|
|
63
71
|
output_tensor: String::new(),
|
|
64
72
|
mode: ExtractorMode::Raw,
|
|
65
73
|
with_type_ids: false,
|
|
66
74
|
with_attention_mask: true,
|
|
67
|
-
num_threads
|
|
75
|
+
num_threads,
|
|
68
76
|
optimization_level,
|
|
77
|
+
execution_providers: execution_providers_override.map(str::to_string),
|
|
69
78
|
};
|
|
70
|
-
let
|
|
79
|
+
let session = build_session(&model_path, &session_config)?;
|
|
71
80
|
|
|
72
|
-
|
|
73
|
-
let with_type_ids = session
|
|
74
|
-
let with_attention_mask = session
|
|
75
|
-
let output_tensor =
|
|
76
|
-
|
|
81
|
+
validate_supported_text_inputs(&session, "text embedding")?;
|
|
82
|
+
let with_type_ids = has_input(&session, "token_type_ids");
|
|
83
|
+
let with_attention_mask = has_input(&session, "attention_mask");
|
|
84
|
+
let output_tensor =
|
|
85
|
+
select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
|
|
77
86
|
let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
|
|
78
87
|
if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
|
|
79
88
|
return Err(GteError::Inference(
|
|
@@ -81,29 +90,17 @@ impl Embedder {
|
|
|
81
90
|
));
|
|
82
91
|
}
|
|
83
92
|
|
|
84
|
-
let tuned_num_threads = tune_num_threads(
|
|
85
|
-
num_threads,
|
|
86
|
-
with_attention_mask,
|
|
87
|
-
with_type_ids,
|
|
88
|
-
output_base.as_str(),
|
|
89
|
-
);
|
|
90
|
-
|
|
91
93
|
let config = ModelConfig {
|
|
92
94
|
max_length,
|
|
93
95
|
output_tensor,
|
|
94
96
|
mode,
|
|
95
97
|
with_type_ids,
|
|
96
98
|
with_attention_mask,
|
|
97
|
-
num_threads
|
|
99
|
+
num_threads,
|
|
98
100
|
optimization_level,
|
|
101
|
+
execution_providers: execution_providers_override.map(str::to_string),
|
|
99
102
|
};
|
|
100
103
|
|
|
101
|
-
if tuned_num_threads != probe_num_threads {
|
|
102
|
-
// Release probe session before rebuilding to minimize transient peak RSS.
|
|
103
|
-
drop(session);
|
|
104
|
-
session = build_session(&model_path, &config)?;
|
|
105
|
-
}
|
|
106
|
-
|
|
107
104
|
let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
|
|
108
105
|
|
|
109
106
|
Ok(Self {
|
|
@@ -125,218 +122,6 @@ impl Embedder {
|
|
|
125
122
|
pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
|
|
126
123
|
run_session(&self.session, tokenized, &self.config)
|
|
127
124
|
}
|
|
128
|
-
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
fn tune_num_threads(
|
|
132
|
-
requested: usize,
|
|
133
|
-
with_attention_mask: bool,
|
|
134
|
-
with_type_ids: bool,
|
|
135
|
-
output_name: &str,
|
|
136
|
-
) -> usize {
|
|
137
|
-
if requested > 0 {
|
|
138
|
-
return requested;
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
let family = infer_model_family(with_attention_mask, with_type_ids, output_name);
|
|
142
|
-
|
|
143
|
-
match family {
|
|
144
|
-
// Puma-like workloads typically run many concurrent single-item requests where
|
|
145
|
-
// one intra-op thread per request gives the best tail behavior.
|
|
146
|
-
ModelFamily::E5Like | ModelFamily::ClipLike => 1,
|
|
147
|
-
// Siglip2 text path benefits from a small intra-op pool under concurrency.
|
|
148
|
-
ModelFamily::SiglipLike => 3,
|
|
149
|
-
ModelFamily::Other => 0,
|
|
150
|
-
}
|
|
151
|
-
}
|
|
152
|
-
|
|
153
|
-
fn infer_model_family(
|
|
154
|
-
with_attention_mask: bool,
|
|
155
|
-
with_type_ids: bool,
|
|
156
|
-
output_name: &str,
|
|
157
|
-
) -> ModelFamily {
|
|
158
|
-
if output_name == "last_hidden_state" && with_attention_mask && with_type_ids {
|
|
159
|
-
return ModelFamily::E5Like;
|
|
160
|
-
}
|
|
161
|
-
if output_name == "last_hidden_state" && with_attention_mask && !with_type_ids {
|
|
162
|
-
return ModelFamily::SiglipLike;
|
|
163
|
-
}
|
|
164
|
-
if output_name == "text_embeds" && !with_attention_mask {
|
|
165
|
-
return ModelFamily::ClipLike;
|
|
166
|
-
}
|
|
167
|
-
ModelFamily::Other
|
|
168
|
-
}
|
|
169
|
-
|
|
170
|
-
fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
|
|
171
|
-
let candidates = [dir.join("onnx").join(name), dir.join(name)];
|
|
172
|
-
for path in &candidates {
|
|
173
|
-
if path.exists() {
|
|
174
|
-
return Ok(path.clone());
|
|
175
|
-
}
|
|
176
|
-
}
|
|
177
|
-
Err(GteError::Inference(format!(
|
|
178
|
-
"model '{}' not found in {} (checked onnx/{0} and {0})",
|
|
179
|
-
name,
|
|
180
|
-
dir.display()
|
|
181
|
-
)))
|
|
182
|
-
}
|
|
183
|
-
|
|
184
|
-
fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
|
|
185
|
-
let candidates = [
|
|
186
|
-
dir.join("onnx").join("text_model.onnx"),
|
|
187
|
-
dir.join("text_model.onnx"),
|
|
188
|
-
dir.join("onnx").join("model.onnx"),
|
|
189
|
-
dir.join("model.onnx"),
|
|
190
|
-
];
|
|
191
|
-
for path in &candidates {
|
|
192
|
-
if path.exists() {
|
|
193
|
-
return Ok(path.clone());
|
|
194
|
-
}
|
|
195
|
-
}
|
|
196
|
-
Err(GteError::Inference(format!(
|
|
197
|
-
"no ONNX model found in {} (checked text_model.onnx and model.onnx)",
|
|
198
|
-
dir.display()
|
|
199
|
-
)))
|
|
200
|
-
}
|
|
201
|
-
|
|
202
|
-
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
|
|
203
|
-
|
|
204
|
-
fn validate_supported_inputs(session: &Session) -> Result<()> {
|
|
205
|
-
let unsupported: Vec<String> = session
|
|
206
|
-
.inputs
|
|
207
|
-
.iter()
|
|
208
|
-
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
|
|
209
|
-
.map(|i| i.name.clone())
|
|
210
|
-
.collect();
|
|
211
|
-
|
|
212
|
-
if unsupported.is_empty() {
|
|
213
|
-
return Ok(());
|
|
214
|
-
}
|
|
215
|
-
|
|
216
|
-
let mut message = format!(
|
|
217
|
-
"unsupported model inputs for text embedding API: {}",
|
|
218
|
-
unsupported.join(", ")
|
|
219
|
-
);
|
|
220
|
-
if unsupported.iter().any(|n| n == "pixel_values") {
|
|
221
|
-
message.push_str(
|
|
222
|
-
". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
|
|
223
|
-
);
|
|
224
|
-
} else {
|
|
225
|
-
message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
|
|
226
|
-
}
|
|
227
|
-
Err(GteError::Inference(message))
|
|
228
|
-
}
|
|
229
|
-
|
|
230
|
-
fn output_name_matches(name: &str, preferred: &str) -> bool {
|
|
231
|
-
let lower = name.to_ascii_lowercase();
|
|
232
|
-
lower == preferred || lower.ends_with(&format!("/{}", preferred))
|
|
233
|
-
}
|
|
234
|
-
|
|
235
|
-
fn select_output_tensor(session: &Session) -> Result<String> {
|
|
236
|
-
const PREFERRED: [&str; 4] = [
|
|
237
|
-
"text_embeds",
|
|
238
|
-
"pooler_output",
|
|
239
|
-
"sentence_embedding",
|
|
240
|
-
"last_hidden_state",
|
|
241
|
-
];
|
|
242
|
-
|
|
243
|
-
for preferred in PREFERRED {
|
|
244
|
-
if let Some(output) = session
|
|
245
|
-
.outputs
|
|
246
|
-
.iter()
|
|
247
|
-
.find(|o| output_name_matches(o.name.as_str(), preferred))
|
|
248
|
-
{
|
|
249
|
-
return Ok(output.name.clone());
|
|
250
|
-
}
|
|
251
|
-
}
|
|
252
|
-
|
|
253
|
-
session
|
|
254
|
-
.outputs
|
|
255
|
-
.first()
|
|
256
|
-
.map(|o| o.name.clone())
|
|
257
|
-
.ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
258
|
-
}
|
|
259
|
-
|
|
260
|
-
fn read_max_length(dir: &Path) -> usize {
|
|
261
|
-
(|| -> Option<usize> {
|
|
262
|
-
let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
|
|
263
|
-
let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
|
|
264
|
-
let v = json.get("model_max_length")?;
|
|
265
|
-
let n = v
|
|
266
|
-
.as_u64()
|
|
267
|
-
.or_else(|| v.as_f64().filter(|&f| f > 0.0 && f < 1e15).map(|f| f as u64))?;
|
|
268
|
-
Some((n as usize).min(8192))
|
|
269
|
-
})()
|
|
270
|
-
.unwrap_or(512)
|
|
271
|
-
}
|
|
272
|
-
|
|
273
|
-
#[cfg(test)]
|
|
274
|
-
mod tests {
|
|
275
|
-
use super::{infer_model_family, tune_num_threads, ModelFamily};
|
|
276
|
-
|
|
277
|
-
#[test]
|
|
278
|
-
fn infer_model_family_recognizes_known_signatures() {
|
|
279
|
-
assert_eq!(
|
|
280
|
-
infer_model_family(true, true, "last_hidden_state"),
|
|
281
|
-
ModelFamily::E5Like
|
|
282
|
-
);
|
|
283
|
-
assert_eq!(
|
|
284
|
-
infer_model_family(true, false, "last_hidden_state"),
|
|
285
|
-
ModelFamily::SiglipLike
|
|
286
|
-
);
|
|
287
|
-
assert_eq!(
|
|
288
|
-
infer_model_family(false, false, "text_embeds"),
|
|
289
|
-
ModelFamily::ClipLike
|
|
290
|
-
);
|
|
291
|
-
assert_eq!(infer_model_family(true, false, "pooler_output"), ModelFamily::Other);
|
|
292
|
-
}
|
|
293
|
-
|
|
294
|
-
#[test]
|
|
295
|
-
fn tune_num_threads_respects_requested_value() {
|
|
296
|
-
assert_eq!(tune_num_threads(7, true, true, "last_hidden_state"), 7);
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
#[test]
|
|
300
|
-
fn tune_num_threads_returns_ort_default_for_other_family() {
|
|
301
|
-
assert_eq!(tune_num_threads(0, true, false, "pooler_output"), 0);
|
|
302
|
-
}
|
|
303
|
-
}
|
|
304
|
-
|
|
305
|
-
fn output_basename(name: &str) -> &str {
|
|
306
|
-
name.rsplit('/').next().unwrap_or(name)
|
|
307
|
-
}
|
|
308
|
-
|
|
309
|
-
fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
|
|
310
|
-
let output = session
|
|
311
|
-
.outputs
|
|
312
|
-
.iter()
|
|
313
|
-
.find(|o| o.name == output_tensor)
|
|
314
|
-
.ok_or_else(|| {
|
|
315
|
-
GteError::Inference(format!(
|
|
316
|
-
"output tensor '{}' not found in model outputs",
|
|
317
|
-
output_tensor
|
|
318
|
-
))
|
|
319
|
-
})?;
|
|
320
|
-
|
|
321
|
-
let ndims = match &output.output_type {
|
|
322
|
-
ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
|
|
323
|
-
other => {
|
|
324
|
-
return Err(GteError::Inference(format!(
|
|
325
|
-
"output is not a tensor: {:?}",
|
|
326
|
-
other
|
|
327
|
-
)))
|
|
328
|
-
}
|
|
329
|
-
};
|
|
330
|
-
|
|
331
|
-
match (output_basename(output_tensor), ndims) {
|
|
332
|
-
("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
|
|
333
|
-
(_, 2) => Ok(ExtractorMode::Raw),
|
|
334
|
-
(_, 3) => Ok(ExtractorMode::MeanPool),
|
|
335
|
-
(_, n) => Err(GteError::Inference(format!(
|
|
336
|
-
"unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
|
|
337
|
-
n, output_tensor
|
|
338
|
-
))),
|
|
339
|
-
}
|
|
340
125
|
}
|
|
341
126
|
|
|
342
127
|
pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
|
data/ext/gte/src/lib.rs
CHANGED