gte 0.0.4 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +118 -13
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +35 -253
- data/ext/gte/src/lib.rs +3 -0
- data/ext/gte/src/model_profile.rs +179 -0
- data/ext/gte/src/pipeline.rs +60 -0
- data/ext/gte/src/postprocess.rs +6 -0
- data/ext/gte/src/reranker.rs +120 -0
- data/ext/gte/src/ruby_embedder.rs +165 -7
- data/ext/gte/src/session.rs +9 -39
- data/ext/gte/src/tokenizer.rs +21 -2
- data/ext/gte/tests/inference_integration_test.rs +8 -4
- data/ext/gte/tests/postprocess_unit_test.rs +17 -0
- data/ext/gte/tests/tokenizer_unit_test.rs +4 -1
- data/lib/gte/config.rb +15 -0
- data/lib/gte/model.rb +35 -0
- data/lib/gte/reranker.rb +54 -0
- data/lib/gte/version.rb +5 -0
- data/lib/gte.rb +24 -35
- metadata +10 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: ae83f737b57f798d39cf1fdc895d67948de27d36b46ea02c211a440d3acaa8c9
|
|
4
|
+
data.tar.gz: 9eaf9651b2ccf1fdb93efe4666ed70537628453a8cf92e234b454560560a83e8
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: a262194a53bf804e47b0ef9c5910c1e2b814a9824823a92a73867a631c7b26310b3163e61997d9c163dab402a40d49946b76a64cc0421741ae235f623180cb95
|
|
7
|
+
data.tar.gz: 6acf5b58140012df9fa25971ed0f1fdfa707cc3efbe5f7f22104e35ad57877778a08cf9f8b311017be8f40e255289e3249e35c1e3780ae231f9f66e08cbb6ac3
|
data/README.md
CHANGED
|
@@ -9,14 +9,105 @@ Inspired by https://github.com/fbilhaut/gte-rs
|
|
|
9
9
|
```ruby
|
|
10
10
|
require "gte"
|
|
11
11
|
|
|
12
|
-
model = GTE.
|
|
13
|
-
|
|
12
|
+
model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
13
|
+
|
|
14
|
+
# String input => GTE::Tensor (1 row)
|
|
15
|
+
tensor = model.embed("query: hello world")
|
|
16
|
+
vector = tensor.row(0)
|
|
17
|
+
|
|
18
|
+
# [] with string => Array<Float> (single vector)
|
|
19
|
+
single = model["query: nearest coffee shop"]
|
|
20
|
+
|
|
21
|
+
# [] with array => GTE::Tensor (batch)
|
|
22
|
+
batch = model[["query: hello", "query: world"]]
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
## Embedding Config (`GTE.config`)
|
|
26
|
+
|
|
27
|
+
`GTE.config(model_dir)` builds (and caches) a `GTE::Model`.
|
|
28
|
+
|
|
29
|
+
```ruby
|
|
30
|
+
default_model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
31
|
+
|
|
32
|
+
raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
33
|
+
config.with(normalize: false)
|
|
34
|
+
end
|
|
35
|
+
|
|
36
|
+
full_throttle = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
37
|
+
config.with(threads: 0)
|
|
38
|
+
end
|
|
39
|
+
|
|
40
|
+
custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
41
|
+
config.with(
|
|
42
|
+
output_tensor: "last_hidden_state",
|
|
43
|
+
max_length: 256,
|
|
44
|
+
optimization_level: 3
|
|
45
|
+
)
|
|
46
|
+
end
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Config fields and defaults:
|
|
50
|
+
|
|
51
|
+
- `model_dir`: absolute path to model directory
|
|
52
|
+
- `threads`: `3` (set `0` for ONNX Runtime full-throttle threadpool)
|
|
53
|
+
- `optimization_level`: `3`
|
|
54
|
+
- `model_name`: `nil`
|
|
55
|
+
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
56
|
+
- `output_tensor`: `nil` (auto-select output tensor)
|
|
57
|
+
- `max_length`: `nil` (uses tokenizer/model defaults)
|
|
58
|
+
|
|
59
|
+
Notes:
|
|
60
|
+
|
|
61
|
+
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
62
|
+
- Model instances are cached by full config key; different config values create different cached instances.
|
|
63
|
+
|
|
64
|
+
## Reranker
|
|
65
|
+
|
|
66
|
+
Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
|
|
67
|
+
|
|
68
|
+
```ruby
|
|
69
|
+
reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
|
|
70
|
+
config.with(sigmoid: true, threads: 0)
|
|
71
|
+
end
|
|
72
|
+
|
|
73
|
+
query = "how to train a neural network?"
|
|
74
|
+
candidates = [
|
|
75
|
+
"Backpropagation and gradient descent are core techniques.",
|
|
76
|
+
"This recipe uses flour and eggs."
|
|
77
|
+
]
|
|
78
|
+
|
|
79
|
+
# Raw scores aligned with input order
|
|
80
|
+
scores = reranker.score(query, candidates)
|
|
81
|
+
# => [0.93, 0.07]
|
|
82
|
+
|
|
83
|
+
# Ranked output sorted by score desc
|
|
84
|
+
ranked = reranker.rerank(query: query, candidates: candidates)
|
|
85
|
+
# => [
|
|
86
|
+
# { index: 0, score: 0.93, text: "Backpropagation and gradient descent are core techniques." },
|
|
87
|
+
# { index: 1, score: 0.07, text: "This recipe uses flour and eggs." }
|
|
88
|
+
# ]
|
|
14
89
|
```
|
|
15
90
|
|
|
16
|
-
|
|
91
|
+
Reranker config fields and defaults:
|
|
92
|
+
|
|
93
|
+
- `model_dir`: absolute path to model directory
|
|
94
|
+
- `threads`: `3`
|
|
95
|
+
- `optimization_level`: `3`
|
|
96
|
+
- `model_name`: `nil`
|
|
97
|
+
- `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
|
|
98
|
+
- `output_tensor`: `nil`
|
|
99
|
+
- `max_length`: `nil`
|
|
100
|
+
|
|
101
|
+
## Runtime + Result Examples
|
|
102
|
+
|
|
103
|
+
Process-local reuse (recommended for Puma/web servers):
|
|
17
104
|
|
|
18
105
|
```ruby
|
|
19
|
-
|
|
106
|
+
EMBEDDER = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
107
|
+
|
|
108
|
+
def embed_query(text)
|
|
109
|
+
EMBEDDER[text] # Array<Float>
|
|
110
|
+
end
|
|
20
111
|
```
|
|
21
112
|
|
|
22
113
|
## Model Directory
|
|
@@ -28,14 +119,28 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
|
|
|
28
119
|
3. `onnx/model.onnx`
|
|
29
120
|
4. `model.onnx`
|
|
30
121
|
|
|
122
|
+
Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
|
|
123
|
+
|
|
124
|
+
## Execution Providers
|
|
125
|
+
|
|
126
|
+
Default execution provider is `xnnpack` on all platforms (including macOS arm64).
|
|
127
|
+
|
|
128
|
+
To opt in to CoreML explicitly:
|
|
129
|
+
|
|
130
|
+
```bash
|
|
131
|
+
export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
132
|
+
```
|
|
133
|
+
|
|
31
134
|
## Development
|
|
32
135
|
|
|
33
|
-
Run commands inside `nix develop
|
|
136
|
+
Run commands inside `nix develop` via Make targets:
|
|
34
137
|
|
|
35
138
|
```bash
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
139
|
+
make setup
|
|
140
|
+
make compile
|
|
141
|
+
make test
|
|
142
|
+
make lint
|
|
143
|
+
make ci
|
|
39
144
|
```
|
|
40
145
|
|
|
41
146
|
## Benchmark
|
|
@@ -43,14 +148,14 @@ bundle exec rspec
|
|
|
43
148
|
The repo includes two benchmark paths:
|
|
44
149
|
|
|
45
150
|
```bash
|
|
46
|
-
|
|
47
|
-
bundle exec rake bench:
|
|
48
|
-
bundle exec rake bench:matrix_sweep
|
|
49
|
-
bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
151
|
+
make bench
|
|
152
|
+
nix develop -c bundle exec rake bench:pure_compare
|
|
153
|
+
nix develop -c bundle exec rake bench:matrix_sweep
|
|
154
|
+
nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
50
155
|
```
|
|
51
156
|
|
|
52
157
|
For release tracking and regression detection, record a run entry in `RUNS.md`:
|
|
53
158
|
|
|
54
159
|
```bash
|
|
55
|
-
|
|
160
|
+
make bench-record
|
|
56
161
|
```
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.4
|
|
1
|
+
0.0.5
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -1,19 +1,15 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::{ExtractorMode, ModelConfig};
|
|
3
|
+
use crate::model_profile::{
|
|
4
|
+
has_input, infer_extraction_mode, read_max_length, resolve_default_text_model, resolve_named_model,
|
|
5
|
+
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
|
+
};
|
|
3
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
4
8
|
use crate::session::{build_session, run_session};
|
|
5
9
|
use crate::tokenizer::{Tokenized, Tokenizer};
|
|
6
10
|
use ndarray::Array2;
|
|
7
11
|
use ort::session::Session;
|
|
8
|
-
use std::path::{Path, PathBuf};
|
|
9
|
-
|
|
10
|
-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
11
|
-
pub enum ModelFamily {
|
|
12
|
-
E5Like,
|
|
13
|
-
SiglipLike,
|
|
14
|
-
ClipLike,
|
|
15
|
-
Other,
|
|
16
|
-
}
|
|
12
|
+
use std::path::Path;
|
|
17
13
|
|
|
18
14
|
pub struct Embedder {
|
|
19
15
|
tokenizer: Tokenizer,
|
|
@@ -41,39 +37,50 @@ impl Embedder {
|
|
|
41
37
|
num_threads: usize,
|
|
42
38
|
optimization_level: u8,
|
|
43
39
|
model_name: Option<&str>,
|
|
40
|
+
output_tensor_override: Option<&str>,
|
|
41
|
+
max_length_override: Option<usize>,
|
|
44
42
|
) -> Result<Self> {
|
|
43
|
+
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
|
|
44
|
+
"pooler_output",
|
|
45
|
+
"text_embeds",
|
|
46
|
+
"sentence_embedding",
|
|
47
|
+
"last_hidden_state",
|
|
48
|
+
];
|
|
49
|
+
|
|
45
50
|
let dir = dir.as_ref();
|
|
46
|
-
let tokenizer_path = dir
|
|
51
|
+
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
47
52
|
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
48
53
|
Some(name) => resolve_named_model(dir, name)?,
|
|
49
|
-
None =>
|
|
54
|
+
None => resolve_default_text_model(dir)?,
|
|
50
55
|
};
|
|
51
56
|
|
|
52
|
-
if
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
57
|
+
let max_length = if let Some(override_value) = max_length_override {
|
|
58
|
+
if override_value == 0 {
|
|
59
|
+
return Err(GteError::Inference(
|
|
60
|
+
"max_length override must be greater than 0".to_string(),
|
|
61
|
+
));
|
|
62
|
+
}
|
|
63
|
+
override_value
|
|
64
|
+
} else {
|
|
65
|
+
read_max_length(dir)
|
|
66
|
+
};
|
|
58
67
|
|
|
59
|
-
let
|
|
60
|
-
let probe_num_threads = if num_threads == 0 { 1 } else { num_threads };
|
|
61
|
-
let temp_config = ModelConfig {
|
|
68
|
+
let session_config = ModelConfig {
|
|
62
69
|
max_length,
|
|
63
70
|
output_tensor: String::new(),
|
|
64
71
|
mode: ExtractorMode::Raw,
|
|
65
72
|
with_type_ids: false,
|
|
66
73
|
with_attention_mask: true,
|
|
67
|
-
num_threads: probe_num_threads,
|
|
74
|
+
num_threads,
|
|
68
75
|
optimization_level,
|
|
69
76
|
};
|
|
70
|
-
let
|
|
77
|
+
let session = build_session(&model_path, &session_config)?;
|
|
71
78
|
|
|
72
|
-
|
|
73
|
-
let with_type_ids = session
|
|
74
|
-
let with_attention_mask = session
|
|
75
|
-
let output_tensor =
|
|
76
|
-
|
|
79
|
+
validate_supported_text_inputs(&session, "text embedding")?;
|
|
80
|
+
let with_type_ids = has_input(&session, "token_type_ids");
|
|
81
|
+
let with_attention_mask = has_input(&session, "attention_mask");
|
|
82
|
+
let output_tensor =
|
|
83
|
+
select_output_tensor(&session, output_tensor_override, &PREFERRED_EMBEDDING_OUTPUTS)?;
|
|
77
84
|
let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
|
|
78
85
|
if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
|
|
79
86
|
return Err(GteError::Inference(
|
|
@@ -81,29 +88,16 @@ impl Embedder {
|
|
|
81
88
|
));
|
|
82
89
|
}
|
|
83
90
|
|
|
84
|
-
let tuned_num_threads = tune_num_threads(
|
|
85
|
-
num_threads,
|
|
86
|
-
with_attention_mask,
|
|
87
|
-
with_type_ids,
|
|
88
|
-
output_base.as_str(),
|
|
89
|
-
);
|
|
90
|
-
|
|
91
91
|
let config = ModelConfig {
|
|
92
92
|
max_length,
|
|
93
93
|
output_tensor,
|
|
94
94
|
mode,
|
|
95
95
|
with_type_ids,
|
|
96
96
|
with_attention_mask,
|
|
97
|
-
num_threads: tuned_num_threads,
|
|
97
|
+
num_threads,
|
|
98
98
|
optimization_level,
|
|
99
99
|
};
|
|
100
100
|
|
|
101
|
-
if tuned_num_threads != probe_num_threads {
|
|
102
|
-
// Release probe session before rebuilding to minimize transient peak RSS.
|
|
103
|
-
drop(session);
|
|
104
|
-
session = build_session(&model_path, &config)?;
|
|
105
|
-
}
|
|
106
|
-
|
|
107
101
|
let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
|
|
108
102
|
|
|
109
103
|
Ok(Self {
|
|
@@ -125,218 +119,6 @@ impl Embedder {
|
|
|
125
119
|
pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
|
|
126
120
|
run_session(&self.session, tokenized, &self.config)
|
|
127
121
|
}
|
|
128
|
-
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
fn tune_num_threads(
|
|
132
|
-
requested: usize,
|
|
133
|
-
with_attention_mask: bool,
|
|
134
|
-
with_type_ids: bool,
|
|
135
|
-
output_name: &str,
|
|
136
|
-
) -> usize {
|
|
137
|
-
if requested > 0 {
|
|
138
|
-
return requested;
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
let family = infer_model_family(with_attention_mask, with_type_ids, output_name);
|
|
142
|
-
|
|
143
|
-
match family {
|
|
144
|
-
// Puma-like workloads typically run many concurrent single-item requests where
|
|
145
|
-
// one intra-op thread per request gives the best tail behavior.
|
|
146
|
-
ModelFamily::E5Like | ModelFamily::ClipLike => 1,
|
|
147
|
-
// Siglip2 text path benefits from a small intra-op pool under concurrency.
|
|
148
|
-
ModelFamily::SiglipLike => 3,
|
|
149
|
-
ModelFamily::Other => 0,
|
|
150
|
-
}
|
|
151
|
-
}
|
|
152
|
-
|
|
153
|
-
fn infer_model_family(
|
|
154
|
-
with_attention_mask: bool,
|
|
155
|
-
with_type_ids: bool,
|
|
156
|
-
output_name: &str,
|
|
157
|
-
) -> ModelFamily {
|
|
158
|
-
if output_name == "last_hidden_state" && with_attention_mask && with_type_ids {
|
|
159
|
-
return ModelFamily::E5Like;
|
|
160
|
-
}
|
|
161
|
-
if output_name == "last_hidden_state" && with_attention_mask && !with_type_ids {
|
|
162
|
-
return ModelFamily::SiglipLike;
|
|
163
|
-
}
|
|
164
|
-
if output_name == "text_embeds" && !with_attention_mask {
|
|
165
|
-
return ModelFamily::ClipLike;
|
|
166
|
-
}
|
|
167
|
-
ModelFamily::Other
|
|
168
|
-
}
|
|
169
|
-
|
|
170
|
-
fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
|
|
171
|
-
let candidates = [dir.join("onnx").join(name), dir.join(name)];
|
|
172
|
-
for path in &candidates {
|
|
173
|
-
if path.exists() {
|
|
174
|
-
return Ok(path.clone());
|
|
175
|
-
}
|
|
176
|
-
}
|
|
177
|
-
Err(GteError::Inference(format!(
|
|
178
|
-
"model '{}' not found in {} (checked onnx/{0} and {0})",
|
|
179
|
-
name,
|
|
180
|
-
dir.display()
|
|
181
|
-
)))
|
|
182
|
-
}
|
|
183
|
-
|
|
184
|
-
fn resolve_model_path(dir: &Path) -> Result<PathBuf> {
|
|
185
|
-
let candidates = [
|
|
186
|
-
dir.join("onnx").join("text_model.onnx"),
|
|
187
|
-
dir.join("text_model.onnx"),
|
|
188
|
-
dir.join("onnx").join("model.onnx"),
|
|
189
|
-
dir.join("model.onnx"),
|
|
190
|
-
];
|
|
191
|
-
for path in &candidates {
|
|
192
|
-
if path.exists() {
|
|
193
|
-
return Ok(path.clone());
|
|
194
|
-
}
|
|
195
|
-
}
|
|
196
|
-
Err(GteError::Inference(format!(
|
|
197
|
-
"no ONNX model found in {} (checked text_model.onnx and model.onnx)",
|
|
198
|
-
dir.display()
|
|
199
|
-
)))
|
|
200
|
-
}
|
|
201
|
-
|
|
202
|
-
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
|
|
203
|
-
|
|
204
|
-
fn validate_supported_inputs(session: &Session) -> Result<()> {
|
|
205
|
-
let unsupported: Vec<String> = session
|
|
206
|
-
.inputs
|
|
207
|
-
.iter()
|
|
208
|
-
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
|
|
209
|
-
.map(|i| i.name.clone())
|
|
210
|
-
.collect();
|
|
211
|
-
|
|
212
|
-
if unsupported.is_empty() {
|
|
213
|
-
return Ok(());
|
|
214
|
-
}
|
|
215
|
-
|
|
216
|
-
let mut message = format!(
|
|
217
|
-
"unsupported model inputs for text embedding API: {}",
|
|
218
|
-
unsupported.join(", ")
|
|
219
|
-
);
|
|
220
|
-
if unsupported.iter().any(|n| n == "pixel_values") {
|
|
221
|
-
message.push_str(
|
|
222
|
-
". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
|
|
223
|
-
);
|
|
224
|
-
} else {
|
|
225
|
-
message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
|
|
226
|
-
}
|
|
227
|
-
Err(GteError::Inference(message))
|
|
228
|
-
}
|
|
229
|
-
|
|
230
|
-
fn output_name_matches(name: &str, preferred: &str) -> bool {
|
|
231
|
-
let lower = name.to_ascii_lowercase();
|
|
232
|
-
lower == preferred || lower.ends_with(&format!("/{}", preferred))
|
|
233
|
-
}
|
|
234
|
-
|
|
235
|
-
fn select_output_tensor(session: &Session) -> Result<String> {
|
|
236
|
-
const PREFERRED: [&str; 4] = [
|
|
237
|
-
"text_embeds",
|
|
238
|
-
"pooler_output",
|
|
239
|
-
"sentence_embedding",
|
|
240
|
-
"last_hidden_state",
|
|
241
|
-
];
|
|
242
|
-
|
|
243
|
-
for preferred in PREFERRED {
|
|
244
|
-
if let Some(output) = session
|
|
245
|
-
.outputs
|
|
246
|
-
.iter()
|
|
247
|
-
.find(|o| output_name_matches(o.name.as_str(), preferred))
|
|
248
|
-
{
|
|
249
|
-
return Ok(output.name.clone());
|
|
250
|
-
}
|
|
251
|
-
}
|
|
252
|
-
|
|
253
|
-
session
|
|
254
|
-
.outputs
|
|
255
|
-
.first()
|
|
256
|
-
.map(|o| o.name.clone())
|
|
257
|
-
.ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
258
|
-
}
|
|
259
|
-
|
|
260
|
-
fn read_max_length(dir: &Path) -> usize {
|
|
261
|
-
(|| -> Option<usize> {
|
|
262
|
-
let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
|
|
263
|
-
let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
|
|
264
|
-
let v = json.get("model_max_length")?;
|
|
265
|
-
let n = v
|
|
266
|
-
.as_u64()
|
|
267
|
-
.or_else(|| v.as_f64().filter(|&f| f > 0.0 && f < 1e15).map(|f| f as u64))?;
|
|
268
|
-
Some((n as usize).min(8192))
|
|
269
|
-
})()
|
|
270
|
-
.unwrap_or(512)
|
|
271
|
-
}
|
|
272
|
-
|
|
273
|
-
#[cfg(test)]
|
|
274
|
-
mod tests {
|
|
275
|
-
use super::{infer_model_family, tune_num_threads, ModelFamily};
|
|
276
|
-
|
|
277
|
-
#[test]
|
|
278
|
-
fn infer_model_family_recognizes_known_signatures() {
|
|
279
|
-
assert_eq!(
|
|
280
|
-
infer_model_family(true, true, "last_hidden_state"),
|
|
281
|
-
ModelFamily::E5Like
|
|
282
|
-
);
|
|
283
|
-
assert_eq!(
|
|
284
|
-
infer_model_family(true, false, "last_hidden_state"),
|
|
285
|
-
ModelFamily::SiglipLike
|
|
286
|
-
);
|
|
287
|
-
assert_eq!(
|
|
288
|
-
infer_model_family(false, false, "text_embeds"),
|
|
289
|
-
ModelFamily::ClipLike
|
|
290
|
-
);
|
|
291
|
-
assert_eq!(infer_model_family(true, false, "pooler_output"), ModelFamily::Other);
|
|
292
|
-
}
|
|
293
|
-
|
|
294
|
-
#[test]
|
|
295
|
-
fn tune_num_threads_respects_requested_value() {
|
|
296
|
-
assert_eq!(tune_num_threads(7, true, true, "last_hidden_state"), 7);
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
#[test]
|
|
300
|
-
fn tune_num_threads_returns_ort_default_for_other_family() {
|
|
301
|
-
assert_eq!(tune_num_threads(0, true, false, "pooler_output"), 0);
|
|
302
|
-
}
|
|
303
|
-
}
|
|
304
|
-
|
|
305
|
-
fn output_basename(name: &str) -> &str {
|
|
306
|
-
name.rsplit('/').next().unwrap_or(name)
|
|
307
|
-
}
|
|
308
|
-
|
|
309
|
-
fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
|
|
310
|
-
let output = session
|
|
311
|
-
.outputs
|
|
312
|
-
.iter()
|
|
313
|
-
.find(|o| o.name == output_tensor)
|
|
314
|
-
.ok_or_else(|| {
|
|
315
|
-
GteError::Inference(format!(
|
|
316
|
-
"output tensor '{}' not found in model outputs",
|
|
317
|
-
output_tensor
|
|
318
|
-
))
|
|
319
|
-
})?;
|
|
320
|
-
|
|
321
|
-
let ndims = match &output.output_type {
|
|
322
|
-
ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
|
|
323
|
-
other => {
|
|
324
|
-
return Err(GteError::Inference(format!(
|
|
325
|
-
"output is not a tensor: {:?}",
|
|
326
|
-
other
|
|
327
|
-
)))
|
|
328
|
-
}
|
|
329
|
-
};
|
|
330
|
-
|
|
331
|
-
match (output_basename(output_tensor), ndims) {
|
|
332
|
-
("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
|
|
333
|
-
(_, 2) => Ok(ExtractorMode::Raw),
|
|
334
|
-
(_, 3) => Ok(ExtractorMode::MeanPool),
|
|
335
|
-
(_, n) => Err(GteError::Inference(format!(
|
|
336
|
-
"unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
|
|
337
|
-
n, output_tensor
|
|
338
|
-
))),
|
|
339
|
-
}
|
|
340
122
|
}
|
|
341
123
|
|
|
342
124
|
pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
|