gte 0.0.13 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +93 -27
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +26 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +136 -248
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +2 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 37ad2ef3f640b8bbaefa14cae29541f329c3b99950a2867b9054b8e8854ca242
|
|
4
|
+
data.tar.gz: 0b54c757ca510ccc8644e0d3c1519aa7b6576fa1da4ba9012535e0b0b7d598dc
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 5d67fc8c73aa2b162bf804082accc849ce972b4df51e2ec86bb7d977cd760e2bf5b760d7de141b5fbd5b6f68dd0e51d2bd15c84d1cb8dc9a8f4bf424b62490ba
|
|
7
|
+
data.tar.gz: c6cefec4d42f7ca72980e4d3454447aa681f05f0b72f96ceec226b5efd5ae32bade13d85ae55f032098d7cb7f72b530d8332f1eaf25fa364d0c33703da13e043
|
data/README.md
CHANGED
|
@@ -58,6 +58,8 @@ Notes:
|
|
|
58
58
|
|
|
59
59
|
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
60
60
|
- Model instances are cached by full config key; different config values create different cached instances.
|
|
61
|
+
- `GTE.warmup(model, threads:)` pre-warms thread-local ONNX sessions eagerly at boot.
|
|
62
|
+
Useful in multi-threaded servers (Puma, Sidekiq) to avoid ~100-500ms cold-start latency.
|
|
61
63
|
|
|
62
64
|
Common model presets:
|
|
63
65
|
|
|
@@ -73,7 +75,7 @@ end
|
|
|
73
75
|
|
|
74
76
|
siglip2 = GTE.config(ENV.fetch("GTE_SIGLIP2_DIR")) do |config|
|
|
75
77
|
config.with(
|
|
76
|
-
model_name: "
|
|
78
|
+
model_name: "text_model.onnx",
|
|
77
79
|
output_tensor: "pooler_output",
|
|
78
80
|
max_length: 64,
|
|
79
81
|
execution_providers: "cpu"
|
|
@@ -147,6 +149,53 @@ Session pool sizing:
|
|
|
147
149
|
- `GTE_SESSION_POOL_CAP`: optional positive integer cap for internal ONNX session pool size.
|
|
148
150
|
- Unset by default; runtime uses available CPU parallelism.
|
|
149
151
|
|
|
152
|
+
## Automatic Tuning
|
|
153
|
+
|
|
154
|
+
`gte` automatically adapts to the hardware — no configuration required.
|
|
155
|
+
|
|
156
|
+
### ONNX Intra-op Threads
|
|
157
|
+
|
|
158
|
+
- Auto-detected via `std::thread::available_parallelism()`, capped at 4.
|
|
159
|
+
- Prevents oversubscription on high-concurrency workloads.
|
|
160
|
+
- Override with `GTE_INTRA_OP_NUM_THREADS` env var.
|
|
161
|
+
|
|
162
|
+
### ONNX Inter-op Threads
|
|
163
|
+
|
|
164
|
+
- Defaults to 1 (text embedding graphs are linear chains with no independent parallel nodes).
|
|
165
|
+
- Override with `GTE_INTER_OP_NUM_THREADS` env var.
|
|
166
|
+
|
|
167
|
+
### Execution Providers
|
|
168
|
+
|
|
169
|
+
`gte` automatically tries XNNPACK for optimized CPU inference. It falls back to
|
|
170
|
+
ORT's default CPU provider if XNNPACK is unavailable.
|
|
171
|
+
|
|
172
|
+
- **ARM64** (Apple Silicon, AWS Graviton): XNNPACK is typically **~25% faster**
|
|
173
|
+
than plain CPU while producing identical embeddings (cos=1.0, max_abs=0.0).
|
|
174
|
+
- **x86/x64** (Intel, AMD): XNNPACK offers minimal benefit — ORT's default CPU
|
|
175
|
+
provider already uses its built-in MLAS kernels, which are better tuned for these chips.
|
|
176
|
+
The auto-detect silently falls back to the default provider.
|
|
177
|
+
|
|
178
|
+
Configure providers explicitly with `GTE_EXECUTION_PROVIDERS` (comma-separated):
|
|
179
|
+
|
|
180
|
+
```bash
|
|
181
|
+
export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
182
|
+
```
|
|
183
|
+
|
|
184
|
+
Set `GTE_EXECUTION_PROVIDERS` to `cpu` or `none` to skip auto-detection and use ORT's default CPU provider.
|
|
185
|
+
|
|
186
|
+
### Session Pre-Warming
|
|
187
|
+
|
|
188
|
+
ONNX sessions are created lazily per OS thread. In multi-threaded servers (Puma, Sidekiq),
|
|
189
|
+
each thread creates its own session on first use (~100-500ms cold start).
|
|
190
|
+
Pre-warm sessions eagerly at boot:
|
|
191
|
+
|
|
192
|
+
```ruby
|
|
193
|
+
model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
194
|
+
|
|
195
|
+
# Pre-warm thread-local sessions for a Puma server with 5 threads:
|
|
196
|
+
GTE.warmup(model, threads: 5)
|
|
197
|
+
```
|
|
198
|
+
|
|
150
199
|
## Runtime + Result Examples
|
|
151
200
|
|
|
152
201
|
Process-local reuse (recommended for Puma/web servers):
|
|
@@ -170,47 +219,64 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
|
|
|
170
219
|
|
|
171
220
|
Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
|
|
172
221
|
|
|
173
|
-
## Execution Providers
|
|
174
222
|
|
|
175
|
-
|
|
223
|
+
## Development
|
|
224
|
+
|
|
225
|
+
Run commands inside `nix develop` via Make targets:
|
|
226
|
+
|
|
227
|
+
```bash
|
|
228
|
+
make setup
|
|
229
|
+
make compile
|
|
230
|
+
make test
|
|
231
|
+
make lint
|
|
232
|
+
make ci
|
|
233
|
+
```
|
|
234
|
+
|
|
235
|
+
## Benchmarks
|
|
176
236
|
|
|
177
|
-
|
|
178
|
-
Supported values:
|
|
237
|
+
### Docker Rails+Puma+wrk (Real-World HTTP)
|
|
179
238
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
- `coreml`
|
|
239
|
+
The `bench/rails/` directory contains a full-stack benchmark: Rails 7.1 API app served by Puma,
|
|
240
|
+
loaded with wrk (randomized text queries, 135 diverse texts).
|
|
183
241
|
|
|
184
|
-
|
|
242
|
+
Run for all models:
|
|
185
243
|
|
|
186
244
|
```bash
|
|
187
|
-
|
|
188
|
-
export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
245
|
+
make bench-docker-compare
|
|
189
246
|
```
|
|
190
247
|
|
|
191
|
-
|
|
248
|
+
Run for a single model:
|
|
192
249
|
|
|
193
|
-
```
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
end
|
|
250
|
+
```bash
|
|
251
|
+
make bench-docker-sweep-siglip2
|
|
252
|
+
make bench-docker-validate # cross-validation checks
|
|
197
253
|
```
|
|
198
254
|
|
|
199
|
-
|
|
255
|
+
#### Siglip2 (768-dim, pooler_output)
|
|
200
256
|
|
|
201
|
-
|
|
257
|
+
| Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
|
|
258
|
+
|------------|---------|---------------|-------|---------|---------------|
|
|
259
|
+
| c=1 | ~12ms | ~120ms | 9-10× | ~95 | ~10 |
|
|
260
|
+
| c=4 | ~39ms | ~503ms | 10-13× | ~228 | ~10 |
|
|
261
|
+
| c=8 | ~146ms | ~613ms | 3-4× | ~224 | ~10 |
|
|
262
|
+
| c=16 | ~430ms | ~611ms | 1-1.5× | ~226 | ~11 |
|
|
202
263
|
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
264
|
+
#### E5 (384-dim, last_hidden_state + mean pool)
|
|
265
|
+
|
|
266
|
+
| Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
|
|
267
|
+
|------------|---------|---------------|-------|---------|---------------|
|
|
268
|
+
| c=1 | ~7ms | ~120ms | 16-17× | ~160 | ~10 |
|
|
269
|
+
| c=4 | ~12ms | ~430ms | 35-40× | ~477 | ~10 |
|
|
270
|
+
| c=8 | ~64ms | ~530ms | 8-9× | ~503 | ~10 |
|
|
271
|
+
| c=16 | ~205ms | ~534ms | 2-3× | ~509 | ~11 |
|
|
272
|
+
|
|
273
|
+
GTE releases the GVL during ONNX inference, enabling true parallelism across Puma threads.
|
|
274
|
+
Pure Ruby is GVL-bound (~10 RPS regardless of concurrency).
|
|
210
275
|
|
|
211
|
-
|
|
276
|
+
The Puma thread pool (min=2, max=5) limits throughput at c=16+.
|
|
277
|
+
GTE's pipelining and GVL release already saturate the available threads at c=4.
|
|
212
278
|
|
|
213
|
-
|
|
279
|
+
### In-Process Benchmarks
|
|
214
280
|
|
|
215
281
|
```bash
|
|
216
282
|
make bench
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.
|
|
1
|
+
0.0.14
|
data/ext/gte/Cargo.toml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "gte"
|
|
3
|
-
version = "0.0.
|
|
3
|
+
version = "0.0.14"
|
|
4
4
|
edition = "2021"
|
|
5
5
|
authors = ["elcuervo <elcuervo@elcuervo.net>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -22,11 +22,8 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
|
|
|
22
22
|
rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
|
|
23
23
|
magnus = { version = "0.8", optional = true }
|
|
24
24
|
ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
|
|
25
|
-
ort-sys = "=2.0.0-rc.12"
|
|
26
25
|
tokenizers = "0.21.0"
|
|
27
26
|
ndarray = "0.17"
|
|
28
|
-
half = "2"
|
|
29
|
-
serde = { version = "1", features = ["derive"] }
|
|
30
27
|
serde_json = "1"
|
|
31
28
|
|
|
32
29
|
[dev-dependencies]
|
|
@@ -35,3 +32,28 @@ criterion = "0.5"
|
|
|
35
32
|
[[bench]]
|
|
36
33
|
name = "hot_path"
|
|
37
34
|
harness = false
|
|
35
|
+
|
|
36
|
+
[lints.rust]
|
|
37
|
+
unsafe_code = "deny"
|
|
38
|
+
rust_2018_idioms = "warn"
|
|
39
|
+
unused_qualifications = "warn"
|
|
40
|
+
unused_results = "warn"
|
|
41
|
+
|
|
42
|
+
[lints.clippy]
|
|
43
|
+
all = "warn"
|
|
44
|
+
pedantic = "warn"
|
|
45
|
+
# Reasonable exceptions to pedantic lints:
|
|
46
|
+
module_name_repetitions = "allow"
|
|
47
|
+
missing_errors_doc = "allow"
|
|
48
|
+
must_use_candidate = "allow"
|
|
49
|
+
cast_possible_truncation = "allow"
|
|
50
|
+
cast_sign_loss = "allow"
|
|
51
|
+
cast_precision_loss = "allow"
|
|
52
|
+
similar_names = "allow"
|
|
53
|
+
too_many_lines = "allow"
|
|
54
|
+
# ndarray::ArrayView types are Copy — passing by value is idiomatic
|
|
55
|
+
needless_pass_by_value = "allow"
|
|
56
|
+
# ort::Outlet::name() is on a private type — closure is required
|
|
57
|
+
redundant_closure_for_method_calls = "allow"
|
|
58
|
+
# Transitive dep conflicts are not actionable in this crate
|
|
59
|
+
multiple_crate_versions = "allow"
|
data/ext/gte/benches/hot_path.rs
CHANGED
|
@@ -5,9 +5,7 @@ use gte::postprocess::{mean_pool, normalize_l2};
|
|
|
5
5
|
use ndarray::{Array2, Array3};
|
|
6
6
|
|
|
7
7
|
fn build_hidden_states(batch: usize, seq: usize, dim: usize) -> Array3<f32> {
|
|
8
|
-
Array3::from_shape_fn((batch, seq, dim), |(b, s, d)|
|
|
9
|
-
(((b * 31 + s * 17 + d * 13) % 97) as f32) / 97.0
|
|
10
|
-
})
|
|
8
|
+
Array3::from_shape_fn((batch, seq, dim), |(b, s, d)| (((b * 31 + s * 17 + d * 13) % 97) as f32) / 97.0)
|
|
11
9
|
}
|
|
12
10
|
|
|
13
11
|
fn build_attention_mask(batch: usize, seq: usize) -> Array2<i64> {
|
|
@@ -22,15 +20,7 @@ fn bench_mean_pool(c: &mut Criterion) {
|
|
|
22
20
|
group.bench_with_input(
|
|
23
21
|
BenchmarkId::from_parameter(format!("{batch}x{seq}x{dim}")),
|
|
24
22
|
&(batch, seq, dim),
|
|
25
|
-
|b, _|
|
|
26
|
-
b.iter(|| {
|
|
27
|
-
mean_pool(
|
|
28
|
-
black_box(hidden_states.view()),
|
|
29
|
-
black_box(attention_mask.view()),
|
|
30
|
-
)
|
|
31
|
-
.unwrap()
|
|
32
|
-
})
|
|
33
|
-
},
|
|
23
|
+
|b, _| b.iter(|| mean_pool(black_box(hidden_states.view()), black_box(attention_mask.view())).unwrap()),
|
|
34
24
|
);
|
|
35
25
|
}
|
|
36
26
|
group.finish();
|
|
@@ -39,14 +29,10 @@ fn bench_mean_pool(c: &mut Criterion) {
|
|
|
39
29
|
fn bench_normalize_l2(c: &mut Criterion) {
|
|
40
30
|
let mut group = c.benchmark_group("normalize_l2");
|
|
41
31
|
for (rows, dim) in [(1, 384), (8, 384), (32, 768), (128, 768)] {
|
|
42
|
-
let embeddings = Array2::from_shape_fn((rows, dim), |(row, col)|
|
|
43
|
-
|
|
32
|
+
let embeddings = Array2::from_shape_fn((rows, dim), |(row, col)| (((row * 19 + col * 7) % 113) as f32) / 113.0);
|
|
33
|
+
group.bench_with_input(BenchmarkId::from_parameter(format!("{rows}x{dim}")), &(rows, dim), |b, _| {
|
|
34
|
+
b.iter(|| normalize_l2(black_box(embeddings.clone())))
|
|
44
35
|
});
|
|
45
|
-
group.bench_with_input(
|
|
46
|
-
BenchmarkId::from_parameter(format!("{rows}x{dim}")),
|
|
47
|
-
&(rows, dim),
|
|
48
|
-
|b, _| b.iter(|| normalize_l2(black_box(embeddings.clone()))),
|
|
49
|
-
);
|
|
50
36
|
}
|
|
51
37
|
group.finish();
|
|
52
38
|
}
|
|
@@ -61,26 +47,14 @@ fn bench_padding_impact(c: &mut Criterion) {
|
|
|
61
47
|
let dim = 768;
|
|
62
48
|
let mut group = c.benchmark_group("padding_impact");
|
|
63
49
|
|
|
64
|
-
for (label, seq) in
|
|
65
|
-
("batch_longest/4tok", 4usize),
|
|
66
|
-
|
|
67
|
-
("fixed/e5_max_512", 512usize),
|
|
68
|
-
] {
|
|
50
|
+
for (label, seq) in
|
|
51
|
+
[("batch_longest/4tok", 4usize), ("fixed/siglip2_max_64", 64usize), ("fixed/e5_max_512", 512usize)]
|
|
52
|
+
{
|
|
69
53
|
let hidden_states = build_hidden_states(1, seq, dim);
|
|
70
54
|
let attention_mask = build_attention_mask(1, seq);
|
|
71
|
-
group.bench_with_input(
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|b, _| {
|
|
75
|
-
b.iter(|| {
|
|
76
|
-
mean_pool(
|
|
77
|
-
black_box(hidden_states.view()),
|
|
78
|
-
black_box(attention_mask.view()),
|
|
79
|
-
)
|
|
80
|
-
.unwrap()
|
|
81
|
-
})
|
|
82
|
-
},
|
|
83
|
-
);
|
|
55
|
+
group.bench_with_input(BenchmarkId::from_parameter(label), &seq, |b, _| {
|
|
56
|
+
b.iter(|| mean_pool(black_box(hidden_states.view()), black_box(attention_mask.view())).unwrap())
|
|
57
|
+
});
|
|
84
58
|
}
|
|
85
59
|
group.finish();
|
|
86
60
|
}
|
|
@@ -93,7 +67,12 @@ fn bench_padding_impact(c: &mut Criterion) {
|
|
|
93
67
|
// Sweeps execution providers for quick local comparison.
|
|
94
68
|
fn bench_embedding_e2e(c: &mut Criterion) {
|
|
95
69
|
let cases = [
|
|
96
|
-
(
|
|
70
|
+
(
|
|
71
|
+
"e5",
|
|
72
|
+
"GTE_BENCH_E5_DIR",
|
|
73
|
+
"query: cat",
|
|
74
|
+
"query: ".to_string() + &"the quick brown fox jumps over the lazy dog ".repeat(20),
|
|
75
|
+
),
|
|
97
76
|
("siglip2", "GTE_BENCH_SIGLIP2_DIR", "cat", "a photo of ".to_string() + &"a cat sitting on a mat ".repeat(10)),
|
|
98
77
|
("clip", "GTE_BENCH_CLIP_DIR", "cat", "a photo of ".to_string() + &"a cat sitting on a mat ".repeat(10)),
|
|
99
78
|
];
|
|
@@ -107,10 +86,7 @@ fn bench_embedding_e2e(c: &mut Criterion) {
|
|
|
107
86
|
};
|
|
108
87
|
|
|
109
88
|
for provider in ["cpu", "xnnpack"] {
|
|
110
|
-
let overrides = ModelLoadOverrides {
|
|
111
|
-
execution_providers: Some(provider),
|
|
112
|
-
..ModelLoadOverrides::default()
|
|
113
|
-
};
|
|
89
|
+
let overrides = ModelLoadOverrides { execution_providers: Some(provider), ..ModelLoadOverrides::default() };
|
|
114
90
|
let embedder = match Embedder::from_dir(&dir, 3, overrides) {
|
|
115
91
|
Ok(e) => e,
|
|
116
92
|
Err(err) => {
|
|
@@ -122,11 +98,7 @@ fn bench_embedding_e2e(c: &mut Criterion) {
|
|
|
122
98
|
for (input_label, input) in [("short", short_input.to_string()), ("long", long_input.clone())] {
|
|
123
99
|
let id = BenchmarkId::from_parameter(format!("{model_label}/{provider}/{input_label}"));
|
|
124
100
|
group.bench_with_input(id, &input, |b, text| {
|
|
125
|
-
b.iter(||
|
|
126
|
-
embedder
|
|
127
|
-
.embed(black_box(vec![text.clone()]))
|
|
128
|
-
.expect("embed succeeds")
|
|
129
|
-
})
|
|
101
|
+
b.iter(|| embedder.embed(black_box(&[text.clone()])).expect("embed succeeds"))
|
|
130
102
|
});
|
|
131
103
|
}
|
|
132
104
|
}
|
|
@@ -134,11 +106,5 @@ fn bench_embedding_e2e(c: &mut Criterion) {
|
|
|
134
106
|
group.finish();
|
|
135
107
|
}
|
|
136
108
|
|
|
137
|
-
criterion_group!(
|
|
138
|
-
benches,
|
|
139
|
-
bench_mean_pool,
|
|
140
|
-
bench_normalize_l2,
|
|
141
|
-
bench_padding_impact,
|
|
142
|
-
bench_embedding_e2e
|
|
143
|
-
);
|
|
109
|
+
criterion_group!(benches, bench_mean_pool, bench_normalize_l2, bench_padding_impact, bench_embedding_e2e);
|
|
144
110
|
criterion_main!(benches);
|
data/ext/gte/build.rs
CHANGED
|
@@ -1,15 +1,11 @@
|
|
|
1
1
|
fn main() {
|
|
2
|
-
let version = std::fs::read_to_string("../../VERSION")
|
|
3
|
-
.expect("VERSION file not found")
|
|
4
|
-
.trim()
|
|
5
|
-
.to_string();
|
|
2
|
+
let version = std::fs::read_to_string("../../VERSION").expect("VERSION file not found").trim().to_string();
|
|
6
3
|
|
|
7
4
|
let cargo_version = env!("CARGO_PKG_VERSION");
|
|
8
5
|
|
|
9
6
|
assert_eq!(
|
|
10
7
|
version, cargo_version,
|
|
11
|
-
"VERSION file ({}) doesn't match Cargo.toml ({}). Update Cargo.toml to match.",
|
|
12
|
-
version, cargo_version
|
|
8
|
+
"VERSION file ({version}) doesn't match Cargo.toml ({cargo_version}). Update Cargo.toml to match.",
|
|
13
9
|
);
|
|
14
10
|
|
|
15
11
|
println!("cargo:rerun-if-changed=../../VERSION");
|
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
|
|
3
3
|
use crate::model_profile::{
|
|
4
|
-
has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
|
|
5
|
-
|
|
4
|
+
has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
|
|
5
|
+
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
6
|
};
|
|
7
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
8
|
-
use crate::session::{build_session,
|
|
8
|
+
use crate::session::{build_session, SessionPool};
|
|
9
9
|
use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
|
|
10
10
|
use ndarray::Array2;
|
|
11
11
|
use std::path::{Path, PathBuf};
|
|
@@ -13,7 +13,7 @@ use std::path::{Path, PathBuf};
|
|
|
13
13
|
pub struct Embedder {
|
|
14
14
|
tokenizer: Tokenizer,
|
|
15
15
|
pool: SessionPool,
|
|
16
|
-
config: ModelConfig,
|
|
16
|
+
pub config: ModelConfig,
|
|
17
17
|
}
|
|
18
18
|
|
|
19
19
|
impl Embedder {
|
|
@@ -22,30 +22,17 @@ impl Embedder {
|
|
|
22
22
|
P1: AsRef<Path>,
|
|
23
23
|
P2: AsRef<Path>,
|
|
24
24
|
{
|
|
25
|
-
let tokenizer =
|
|
26
|
-
tokenizer_path,
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
None,
|
|
31
|
-
)?;
|
|
32
|
-
let model_path = model_path.as_ref().to_path_buf();
|
|
33
|
-
let session = build_session(&model_path, &config)?;
|
|
34
|
-
let pool = SessionPool::new(session, model_path, config.clone());
|
|
25
|
+
let tokenizer =
|
|
26
|
+
Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids, config.padding_mode, None)?;
|
|
27
|
+
let model_path = model_path.as_ref();
|
|
28
|
+
let session = build_session(model_path, &config)?;
|
|
29
|
+
let pool = SessionPool::new(session, model_path, &config)?;
|
|
35
30
|
Ok(Self { tokenizer, pool, config })
|
|
36
31
|
}
|
|
37
32
|
|
|
38
|
-
pub fn from_dir<P: AsRef<Path>>(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
overrides: ModelLoadOverrides<'_>,
|
|
42
|
-
) -> Result<Self> {
|
|
43
|
-
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
|
|
44
|
-
"pooler_output",
|
|
45
|
-
"text_embeds",
|
|
46
|
-
"sentence_embedding",
|
|
47
|
-
"last_hidden_state",
|
|
48
|
-
];
|
|
33
|
+
pub fn from_dir<P: AsRef<Path>>(dir: P, optimization_level: u8, overrides: ModelLoadOverrides<'_>) -> Result<Self> {
|
|
34
|
+
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] =
|
|
35
|
+
["pooler_output", "text_embeds", "sentence_embedding", "last_hidden_state"];
|
|
49
36
|
|
|
50
37
|
let dir = dir.as_ref();
|
|
51
38
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
@@ -57,16 +44,13 @@ impl Embedder {
|
|
|
57
44
|
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
58
45
|
let max_length = if let Some(override_value) = overrides.max_length {
|
|
59
46
|
if override_value == 0 {
|
|
60
|
-
return Err(GteError::Inference(
|
|
61
|
-
"max_length override must be greater than 0".to_string(),
|
|
62
|
-
));
|
|
47
|
+
return Err(GteError::Inference("max_length override must be greater than 0".to_string()));
|
|
63
48
|
}
|
|
64
49
|
override_value.min(tokenizer_profile.safe_max_length)
|
|
65
50
|
} else {
|
|
66
51
|
tokenizer_profile.default_max_length
|
|
67
52
|
};
|
|
68
|
-
let padding_mode =
|
|
69
|
-
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
53
|
+
let padding_mode = parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
70
54
|
|
|
71
55
|
let session_config = ModelConfig {
|
|
72
56
|
max_length,
|
|
@@ -77,19 +61,18 @@ impl Embedder {
|
|
|
77
61
|
with_attention_mask: true,
|
|
78
62
|
optimization_level,
|
|
79
63
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
64
|
+
lowercase_input: overrides.lowercase_input.unwrap_or(false),
|
|
65
|
+
max_input_chars: overrides.max_input_chars,
|
|
80
66
|
};
|
|
81
67
|
let session = build_session(&model_path, &session_config)?;
|
|
82
68
|
|
|
83
69
|
validate_supported_text_inputs(&session, "text embedding")?;
|
|
84
70
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
85
71
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
86
|
-
let output_tensor =
|
|
87
|
-
select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
|
|
72
|
+
let output_tensor = select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
|
|
88
73
|
let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
|
|
89
74
|
if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
|
|
90
|
-
return Err(GteError::Inference(
|
|
91
|
-
"cannot use mean pooling without attention_mask input".to_string(),
|
|
92
|
-
));
|
|
75
|
+
return Err(GteError::Inference("cannot use mean pooling without attention_mask input".to_string()));
|
|
93
76
|
}
|
|
94
77
|
|
|
95
78
|
let config = ModelConfig {
|
|
@@ -101,6 +84,8 @@ impl Embedder {
|
|
|
101
84
|
with_attention_mask,
|
|
102
85
|
optimization_level,
|
|
103
86
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
87
|
+
lowercase_input: overrides.lowercase_input.unwrap_or(false),
|
|
88
|
+
max_input_chars: overrides.max_input_chars,
|
|
104
89
|
};
|
|
105
90
|
|
|
106
91
|
let tokenizer = Tokenizer::new(
|
|
@@ -111,29 +96,72 @@ impl Embedder {
|
|
|
111
96
|
tokenizer_profile.fixed_padding_length,
|
|
112
97
|
)?;
|
|
113
98
|
|
|
114
|
-
let pool = SessionPool::new(session, model_path, session_config)
|
|
99
|
+
let pool = SessionPool::new(session, &model_path, &session_config)?;
|
|
115
100
|
Ok(Self { tokenizer, pool, config })
|
|
116
101
|
}
|
|
117
102
|
|
|
118
|
-
pub fn embed(&self, texts:
|
|
119
|
-
self.embed_ref(
|
|
103
|
+
pub fn embed(&self, texts: &[String]) -> Result<Array2<f32>> {
|
|
104
|
+
self.embed_ref(texts)
|
|
120
105
|
}
|
|
121
106
|
|
|
122
107
|
pub fn embed_ref(&self, texts: &[String]) -> Result<Array2<f32>> {
|
|
123
|
-
let
|
|
108
|
+
let sanitized: Vec<String>;
|
|
109
|
+
let input = if self.config.lowercase_input || self.config.max_input_chars.is_some() {
|
|
110
|
+
sanitized = texts
|
|
111
|
+
.iter()
|
|
112
|
+
.map(|t| {
|
|
113
|
+
let mut s = if self.config.lowercase_input { t.to_lowercase() } else { t.clone() };
|
|
114
|
+
if let Some(max_chars) = self.config.max_input_chars {
|
|
115
|
+
s.truncate(max_chars.min(s.len()));
|
|
116
|
+
}
|
|
117
|
+
s
|
|
118
|
+
})
|
|
119
|
+
.collect();
|
|
120
|
+
&sanitized
|
|
121
|
+
} else {
|
|
122
|
+
texts
|
|
123
|
+
};
|
|
124
|
+
let tokenized = self.tokenize(input)?;
|
|
124
125
|
self.run(&tokenized)
|
|
125
126
|
}
|
|
126
127
|
|
|
127
|
-
pub fn tokenize(&self, texts: &[String]) ->
|
|
128
|
+
pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
|
|
128
129
|
self.tokenizer.tokenize(texts)
|
|
129
130
|
}
|
|
130
131
|
|
|
131
|
-
pub fn run(&self, tokenized: &Tokenized) ->
|
|
132
|
-
|
|
133
|
-
run_session(&mut session, tokenized, &self.config)
|
|
132
|
+
pub fn run(&self, tokenized: &Tokenized) -> Result<Array2<f32>> {
|
|
133
|
+
self.pool.run(tokenized, &self.config)
|
|
134
134
|
}
|
|
135
135
|
}
|
|
136
136
|
|
|
137
137
|
pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
|
|
138
138
|
normalize_l2_rows(embeddings)
|
|
139
139
|
}
|
|
140
|
+
|
|
141
|
+
pub fn output_name_suggests_normalized(name: &str) -> bool {
|
|
142
|
+
let lower = name.to_ascii_lowercase();
|
|
143
|
+
let base = lower.rsplit('/').next().unwrap_or(&lower);
|
|
144
|
+
base.contains("normalized") || base.contains("l2_norm") || base.contains("l2norm")
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
#[cfg(test)]
|
|
148
|
+
mod normalize_tests {
|
|
149
|
+
use super::output_name_suggests_normalized;
|
|
150
|
+
|
|
151
|
+
#[test]
|
|
152
|
+
fn detects_normalized_output_names() {
|
|
153
|
+
assert!(output_name_suggests_normalized("pooled_sentence_embeddings_debiased_normalized"));
|
|
154
|
+
assert!(output_name_suggests_normalized("embeddings/L2_Normalized"));
|
|
155
|
+
assert!(output_name_suggests_normalized("l2norm_output"));
|
|
156
|
+
assert!(output_name_suggests_normalized("norm/l2_norm_tensor"));
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
#[test]
|
|
160
|
+
fn does_not_detect_raw_output_names() {
|
|
161
|
+
assert!(!output_name_suggests_normalized("last_hidden_state"));
|
|
162
|
+
assert!(!output_name_suggests_normalized("text_embeds"));
|
|
163
|
+
assert!(!output_name_suggests_normalized("pooler_output"));
|
|
164
|
+
assert!(!output_name_suggests_normalized("sentence_embedding"));
|
|
165
|
+
assert!(!output_name_suggests_normalized("logits"));
|
|
166
|
+
}
|
|
167
|
+
}
|
data/ext/gte/src/error.rs
CHANGED
|
@@ -9,10 +9,10 @@ pub enum GteError {
|
|
|
9
9
|
impl std::fmt::Display for GteError {
|
|
10
10
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
11
11
|
match self {
|
|
12
|
-
GteError::Tokenizer(msg) => write!(f, "GTE tokenizer error: {}"
|
|
13
|
-
GteError::Inference(msg) => write!(f, "GTE inference error: {}"
|
|
14
|
-
GteError::Ort(msg) => write!(f, "GTE ORT error: {}"
|
|
15
|
-
GteError::Shape(msg) => write!(f, "GTE shape error: {}"
|
|
12
|
+
GteError::Tokenizer(msg) => write!(f, "GTE tokenizer error: {msg}"),
|
|
13
|
+
GteError::Inference(msg) => write!(f, "GTE inference error: {msg}"),
|
|
14
|
+
GteError::Ort(msg) => write!(f, "GTE ORT error: {msg}"),
|
|
15
|
+
GteError::Shape(msg) => write!(f, "GTE shape error: {msg}"),
|
|
16
16
|
}
|
|
17
17
|
}
|
|
18
18
|
}
|
data/ext/gte/src/lib.rs
CHANGED
|
@@ -19,7 +19,7 @@ use magnus::{prelude::*, Error, Ruby};
|
|
|
19
19
|
fn init(ruby: &Ruby) -> Result<(), Error> {
|
|
20
20
|
let module = ruby.define_module("GTE")?;
|
|
21
21
|
module.define_error("Error", ruby.exception_standard_error())?;
|
|
22
|
-
|
|
22
|
+
ruby_embedder::register(ruby)?;
|
|
23
23
|
std::panic::set_hook(Box::new(|info| {
|
|
24
24
|
let msg = info
|
|
25
25
|
.payload()
|
data/ext/gte/src/model_config.rs
CHANGED
|
@@ -23,6 +23,8 @@ pub struct ModelConfig {
|
|
|
23
23
|
pub with_attention_mask: bool,
|
|
24
24
|
pub optimization_level: u8,
|
|
25
25
|
pub execution_providers: Option<String>,
|
|
26
|
+
pub lowercase_input: bool,
|
|
27
|
+
pub max_input_chars: Option<usize>,
|
|
26
28
|
}
|
|
27
29
|
|
|
28
30
|
#[derive(Debug, Clone, Copy, Default)]
|
|
@@ -32,4 +34,6 @@ pub struct ModelLoadOverrides<'a> {
|
|
|
32
34
|
pub max_length: Option<usize>,
|
|
33
35
|
pub padding: Option<&'a str>,
|
|
34
36
|
pub execution_providers: Option<&'a str>,
|
|
37
|
+
pub lowercase_input: Option<bool>,
|
|
38
|
+
pub max_input_chars: Option<usize>,
|
|
35
39
|
}
|