gte 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +0 -1
- data/README.md +112 -82
- data/Rakefile +0 -9
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +29 -65
- data/ext/gte/src/lib.rs +1 -0
- data/ext/gte/src/model_config.rs +0 -4
- data/ext/gte/src/pipeline.rs +8 -9
- data/ext/gte/src/postprocess.rs +8 -6
- data/ext/gte/src/reranker.rs +7 -10
- data/ext/gte/src/ruby_embedder.rs +10 -33
- data/ext/gte/src/session.rs +50 -156
- data/ext/gte/src/tokenizer.rs +45 -38
- data/ext/gte/tests/embedder_unit_test.rs +1 -1
- data/ext/gte/tests/padding_regression_test.rs +7 -25
- data/ext/gte/tests/tokenizer_unit_test.rs +7 -7
- data/lib/gte/config.rb +1 -2
- data/lib/gte/embedder.rb +2 -14
- data/lib/gte/model.rb +0 -7
- data/lib/gte/reranker.rb +14 -33
- data/lib/gte.rb +4 -25
- metadata +1 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: e0d54266cdf80ae4259c8ff54623c41119123d7576f4b0b86491c76622af0e58
|
|
4
|
+
data.tar.gz: c72b372118445b8273efbc942d61d1dbeb3c23f75958829ac1029976782fcd05
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 060d1de598dd38868dfdd1f2dc81957d57544870050770256138e1db7776b76ae9c9bf480c7065dde2878b70199b86f8a6f89330513c8080db3a1df4046437c0
|
|
7
|
+
data.tar.gz: 11a131891cf664cdaabfee912eff25d0cd5c4a2197fcb41b46b53ba4d17a66b9a3b734dacecc4dbea4f6cb0ee433d861cf90f21bb9c3474f02afd5fb7d355a6f
|
data/Gemfile
CHANGED
data/README.md
CHANGED
|
@@ -15,32 +15,29 @@ model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
|
15
15
|
tensor = model.embed("query: hello world")
|
|
16
16
|
vector = tensor.row(0)
|
|
17
17
|
|
|
18
|
-
#
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
# [] with array => GTE::Tensor (batch)
|
|
22
|
-
batch = model[["query: hello", "query: world"]]
|
|
18
|
+
# Binary f32 bytes (zero-copy to Numo/NumPy)
|
|
19
|
+
bytes = model.embed_binary("query: hello world")
|
|
23
20
|
```
|
|
24
21
|
|
|
25
|
-
## Embedding Config (`GTE
|
|
22
|
+
## Embedding Config (`GTE::Pool`)
|
|
26
23
|
|
|
27
|
-
`GTE.config(model_dir)`
|
|
24
|
+
`GTE.config(model_dir)` creates a new pool with one ONNX session by default.
|
|
28
25
|
|
|
29
26
|
```ruby
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
33
|
-
config.with(normalize: false)
|
|
34
|
-
end
|
|
27
|
+
default = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
28
|
+
default.embed("query: hello world")
|
|
35
29
|
|
|
36
|
-
|
|
30
|
+
# With config overrides
|
|
31
|
+
configurable = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
37
32
|
config.with(
|
|
38
33
|
output_tensor: "last_hidden_state",
|
|
39
|
-
max_length:
|
|
40
|
-
|
|
41
|
-
optimization_level: 3
|
|
34
|
+
max_length: 128,
|
|
35
|
+
execution_providers: "xnnpack"
|
|
42
36
|
)
|
|
43
37
|
end
|
|
38
|
+
|
|
39
|
+
# Explicit pool size (each session costs ~120MB RSS)
|
|
40
|
+
large = GTE.config(ENV.fetch("GTE_MODEL_DIR"), pool_size: 4)
|
|
44
41
|
```
|
|
45
42
|
|
|
46
43
|
Config fields and defaults:
|
|
@@ -48,19 +45,11 @@ Config fields and defaults:
|
|
|
48
45
|
- `model_dir`: absolute path to model directory
|
|
49
46
|
- `optimization_level`: `3`
|
|
50
47
|
- `model_name`: `nil`
|
|
51
|
-
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
52
48
|
- `output_tensor`: `nil` (auto-select output tensor)
|
|
53
49
|
- `max_length`: `nil` (uses tokenizer/model defaults)
|
|
54
50
|
- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
|
|
55
51
|
- `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
|
|
56
52
|
|
|
57
|
-
Notes:
|
|
58
|
-
|
|
59
|
-
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
60
|
-
- Model instances are cached by full config key; different config values create different cached instances.
|
|
61
|
-
- `GTE.warmup(model, threads:)` pre-warms thread-local ONNX sessions eagerly at boot.
|
|
62
|
-
Useful in multi-threaded servers (Puma, Sidekiq) to avoid ~100-500ms cold-start latency.
|
|
63
|
-
|
|
64
53
|
Common model presets:
|
|
65
54
|
|
|
66
55
|
```ruby
|
|
@@ -91,27 +80,28 @@ clip = GTE.config(ENV.fetch("GTE_CLIP_DIR")) do |config|
|
|
|
91
80
|
end
|
|
92
81
|
```
|
|
93
82
|
|
|
94
|
-
|
|
83
|
+
Output selection:
|
|
95
84
|
|
|
96
85
|
- Use `output_tensor:` to request a named model output.
|
|
97
86
|
- `last_hidden_state` gives token-level hidden states and is mean-pooled by `gte` when the tensor is rank 3.
|
|
98
|
-
- `pooler_output`, `sentence_embedding`, and similar 2D tensors are returned directly and
|
|
87
|
+
- `pooler_output`, `sentence_embedding`, and similar 2D tensors are used as-is (no pooling) and then L2-normalized.
|
|
88
|
+
- If the output tensor name suggests already-normalized output (e.g. `l2_norm`, `normalized`), normalization is skipped.
|
|
99
89
|
- If the requested tensor is not present in the model, `gte` raises an error instead of silently falling back.
|
|
100
90
|
|
|
101
|
-
Low-level embedder setup (without
|
|
91
|
+
Low-level embedder setup (without the `GTE::Pool` convenience wrapper):
|
|
102
92
|
|
|
103
93
|
```ruby
|
|
104
|
-
embedder = GTE::Embedder.
|
|
105
|
-
|
|
106
|
-
|
|
94
|
+
embedder = GTE::Embedder.from_config(
|
|
95
|
+
GTE::Embedder.default_config(ENV.fetch("GTE_MODEL_DIR"))
|
|
96
|
+
)
|
|
107
97
|
```
|
|
108
98
|
|
|
109
99
|
## Reranker
|
|
110
100
|
|
|
111
|
-
Use `GTE::Reranker.
|
|
101
|
+
Use `GTE::Reranker.new(model_dir)` for cross-encoder reranking.
|
|
112
102
|
|
|
113
103
|
```ruby
|
|
114
|
-
reranker = GTE::Reranker.
|
|
104
|
+
reranker = GTE::Reranker.new(ENV.fetch("GTE_RERANK_DIR")) do |config|
|
|
115
105
|
config.with(sigmoid: true)
|
|
116
106
|
end
|
|
117
107
|
|
|
@@ -124,13 +114,6 @@ candidates = [
|
|
|
124
114
|
# Raw scores aligned with input order
|
|
125
115
|
scores = reranker.score(query, candidates)
|
|
126
116
|
# => [0.93, 0.07]
|
|
127
|
-
|
|
128
|
-
# Ranked output sorted by score desc
|
|
129
|
-
ranked = reranker.rerank(query: query, candidates: candidates)
|
|
130
|
-
# => [
|
|
131
|
-
# { index: 0, score: 0.93, text: "Backpropagation and gradient descent are core techniques." },
|
|
132
|
-
# { index: 1, score: 0.07, text: "This recipe uses flour and eggs." }
|
|
133
|
-
# ]
|
|
134
117
|
```
|
|
135
118
|
|
|
136
119
|
Reranker config fields and defaults:
|
|
@@ -144,26 +127,10 @@ Reranker config fields and defaults:
|
|
|
144
127
|
- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
|
|
145
128
|
- `execution_providers`: `nil`
|
|
146
129
|
|
|
147
|
-
Session pool sizing:
|
|
148
|
-
|
|
149
|
-
- `GTE_SESSION_POOL_CAP`: optional positive integer cap for internal ONNX session pool size.
|
|
150
|
-
- Unset by default; runtime uses available CPU parallelism.
|
|
151
|
-
|
|
152
130
|
## Automatic Tuning
|
|
153
131
|
|
|
154
132
|
`gte` automatically adapts to the hardware — no configuration required.
|
|
155
133
|
|
|
156
|
-
### ONNX Intra-op Threads
|
|
157
|
-
|
|
158
|
-
- Auto-detected via `std::thread::available_parallelism()` capped at 4.
|
|
159
|
-
- Prevents oversubscription on high-concurrency workloads.
|
|
160
|
-
- Override with `GTE_INTRA_OP_NUM_THREADS` env var.
|
|
161
|
-
|
|
162
|
-
### ONNX Inter-op Threads
|
|
163
|
-
|
|
164
|
-
- Defaults to 1 (text embedding graphs are linear chains with no independent parallel nodes).
|
|
165
|
-
- Override with `GTE_INTER_OP_NUM_THREADS` env var.
|
|
166
|
-
|
|
167
134
|
### Execution Providers
|
|
168
135
|
|
|
169
136
|
`gte` automatically tries XNNPACK for optimized CPU inference. Falls back to
|
|
@@ -183,28 +150,61 @@ export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
|
183
150
|
|
|
184
151
|
Set `cpu` or `none` to skip auto-detect and use ORT's default CPU provider.
|
|
185
152
|
|
|
153
|
+
### Session Pool
|
|
154
|
+
|
|
155
|
+
gte uses a **pre-allocated session pool** per worker — it creates N sessions at
|
|
156
|
+
construction time, where N is determined by:
|
|
157
|
+
|
|
158
|
+
| Priority | Source | Description |
|
|
159
|
+
|----------|--------|-------------|
|
|
160
|
+
| 1 | `GTE_SESSION_POOL_SIZE` | Explicit size (e.g. `4`) |
|
|
161
|
+
| 2 | `PUMA_MAX_THREADS` | Match Puma concurrency (capped at 8) |
|
|
162
|
+
| 3 | Default | `1` (single session, matching the unsplash-api singleton pattern) |
|
|
163
|
+
|
|
164
|
+
The pool is fixed-size: sessions are never created or destroyed after construction.
|
|
165
|
+
When all sessions are busy, the calling thread blocks on `parking_lot::Mutex`
|
|
166
|
+
until a session is released. This avoids the allocation and memory overhead of
|
|
167
|
+
lazy-growing pools while matching the concurrency needs of application threads.
|
|
168
|
+
|
|
186
169
|
### Session Pre-Warming
|
|
187
170
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
171
|
+
The pool is pre-warmed automatically in `GTE.config` — one inference per
|
|
172
|
+
session is run on construction so the first production request never hits a cold
|
|
173
|
+
cache. No manual warmup step needed.
|
|
191
174
|
|
|
192
|
-
|
|
193
|
-
model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
175
|
+
To re-warm (useful after fork in Puma's `on_worker_boot`):
|
|
194
176
|
|
|
195
|
-
|
|
196
|
-
|
|
177
|
+
```ruby
|
|
178
|
+
pool.warmup
|
|
197
179
|
```
|
|
198
180
|
|
|
199
|
-
|
|
181
|
+
### Tuning Performance
|
|
182
|
+
|
|
183
|
+
| Variable | Effect | Default |
|
|
184
|
+
|----------|--------|---------|
|
|
185
|
+
| `GTE_SESSION_POOL_SIZE` | Number of ONNX sessions per worker (fixed-size pool) | `1` (or `PUMA_MAX_THREADS`, capped at 8) |
|
|
186
|
+
| `GTE_INTRA_OP_NUM_THREADS` | Threads ONNX Runtime uses per inference op | `min(CPU cores, 4)` |
|
|
187
|
+
| `GTE_INTER_OP_NUM_THREADS` | Threads for independent graph nodes (irrelevant for text models) | `1` |
|
|
188
|
+
| `GTE_EXECUTION_PROVIDERS` | Comma-separated: `xnnpack`, `coreml`, `cpu` | Auto: `xnnpack` on arm64 |
|
|
189
|
+
|
|
190
|
+
**To squeeze more throughput:**
|
|
191
|
+
- Set `GTE_SESSION_POOL_SIZE` to match or slightly exceed your Puma `MAX_THREADS`.
|
|
192
|
+
- On machines with many cores, reduce `GTE_INTRA_OP_NUM_THREADS` to `1` or `2`
|
|
193
|
+
to avoid CPU oversubscription when multiple sessions run concurrently.
|
|
194
|
+
|
|
195
|
+
**Memory estimation per worker:**
|
|
196
|
+
- Approximate RSS for pool size N (default 1): **N × model file size × 3–5**
|
|
197
|
+
- Each additional session adds ~120MB RSS on arm64 with XNNPACK.
|
|
198
|
+
|
|
199
|
+
## Runtime
|
|
200
200
|
|
|
201
201
|
Process-local reuse (recommended for Puma/web servers):
|
|
202
202
|
|
|
203
203
|
```ruby
|
|
204
|
-
|
|
204
|
+
$gte = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
205
205
|
|
|
206
206
|
def embed_query(text)
|
|
207
|
-
|
|
207
|
+
$gte.embed(text).row(0) # Array<Float>
|
|
208
208
|
end
|
|
209
209
|
```
|
|
210
210
|
|
|
@@ -219,7 +219,6 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
|
|
|
219
219
|
|
|
220
220
|
Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
|
|
221
221
|
|
|
222
|
-
|
|
223
222
|
## Development
|
|
224
223
|
|
|
225
224
|
Run commands inside `nix develop` via Make targets:
|
|
@@ -256,38 +255,37 @@ make bench-docker-validate # cross-validation checks
|
|
|
256
255
|
|
|
257
256
|
| Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
|
|
258
257
|
|------------|---------|---------------|-------|---------|---------------|
|
|
259
|
-
| c=1 | ~
|
|
260
|
-
| c=
|
|
261
|
-
| c=
|
|
262
|
-
| c=
|
|
258
|
+
| c=1 | ~14ms | ~92ms | 6.4× | ~89 | ~21 |
|
|
259
|
+
| c=2 | ~15ms | ~175ms | 11.4× | ~163 | ~21 |
|
|
260
|
+
| c=4 | ~39ms | ~293ms | 7.4× | ~219 | ~24 |
|
|
261
|
+
| c=8 | ~75ms | ~502ms | 6.7× | ~195 | ~24 |
|
|
262
|
+
| c=16 | ~279ms | ~606ms | 2.2× | ~219 | ~26 |
|
|
263
263
|
|
|
264
264
|
#### E5 (384-dim, last_hidden_state + mean pool)
|
|
265
265
|
|
|
266
266
|
| Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
|
|
267
267
|
|------------|---------|---------------|-------|---------|---------------|
|
|
268
|
-
| c=1 | ~
|
|
269
|
-
| c=
|
|
270
|
-
| c=
|
|
271
|
-
| c=
|
|
268
|
+
| c=1 | ~8ms | ~73ms | 9.3× | ~152 | ~32 |
|
|
269
|
+
| c=2 | ~8ms | ~95ms | 11.8× | ~291 | ~36 |
|
|
270
|
+
| c=4 | ~22ms | ~163ms | 7.5× | ~432 | ~45 |
|
|
271
|
+
| c=8 | ~51ms | ~291ms | 5.7× | ~451 | ~43 |
|
|
272
|
+
| c=16 | ~133ms | ~1080ms | 8.1× | ~467 | ~47 |
|
|
272
273
|
|
|
273
|
-
GTE releases the GVL during ONNX inference, enabling true parallelism across
|
|
274
|
-
|
|
274
|
+
GTE releases the GVL during ONNX inference, enabling true parallelism across
|
|
275
|
+
Puma threads and worker processes. Pure Ruby inference is serialized by the GVL
|
|
276
|
+
(~21–47 RPS regardless of concurrency).
|
|
275
277
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
+
Config: Puma workers=2, threads min=2/max=5, cpus=4, mem_limit=3g.
|
|
279
|
+
Docker wrk with random 135-text query set, 15s runs.
|
|
278
280
|
|
|
279
281
|
### In-Process Benchmarks
|
|
280
282
|
|
|
281
283
|
```bash
|
|
282
284
|
make bench
|
|
283
|
-
nix develop -c bundle exec rake bench:pure_compare
|
|
284
|
-
nix develop -c bundle exec rake bench:matrix_sweep
|
|
285
285
|
nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
286
286
|
```
|
|
287
287
|
|
|
288
288
|
- `make bench`: Puma-like single-request comparison at concurrency `16`
|
|
289
|
-
- `rake bench:pure_compare`: batch amortization comparison
|
|
290
|
-
- `rake bench:matrix_sweep`: GTE provider sweep using the shared result schema
|
|
291
289
|
- Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
|
|
292
290
|
|
|
293
291
|
To run benchmark + append a `RUNS.md` entry + enforce goal checks:
|
|
@@ -300,3 +298,35 @@ make bench-record
|
|
|
300
298
|
|
|
301
299
|
- Enforces the goal metric (`response_time_p95`) across every enabled competitor.
|
|
302
300
|
- Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
|
|
301
|
+
|
|
302
|
+
## Fork Safety
|
|
303
|
+
|
|
304
|
+
GTE uses ONNX Runtime sessions that maintain internal thread pools for parallelism
|
|
305
|
+
(`GTE_INTRA_OP_NUM_THREADS`, default `min(cpus, 4)`). These thread pools are
|
|
306
|
+
per-session and do not survive `fork()`: POSIX copies only the calling thread into the child process.
|
|
307
|
+
|
|
308
|
+
**With Puma's `preload_app!`:**
|
|
309
|
+
|
|
310
|
+
Sessions built before `fork()` share memory via COW, but the internal ORT threads
|
|
311
|
+
created during `Session::builder().commit_from_file()` do not exist in the child
|
|
312
|
+
process. When a forked worker calls `session.run()`, ORT must recreate these
|
|
313
|
+
threads, which adds latency to the first inference call.
|
|
314
|
+
|
|
315
|
+
**Recommendations:**
|
|
316
|
+
|
|
317
|
+
1. **Set `GTE_INTRA_OP_NUM_THREADS=1`** in forked environments to avoid creating
|
|
318
|
+
per-session thread pools entirely. ORT will run inference single-threaded,
|
|
319
|
+
which is acceptable when multiple sessions handle concurrency.
|
|
320
|
+
2. **Build sessions in `on_worker_boot`** instead of before fork to guarantee
|
|
321
|
+
fresh thread pools in each worker. This adds ~200ms to worker startup per
|
|
322
|
+
model but ensures consistent inference latency:
|
|
323
|
+
|
|
324
|
+
```ruby
|
|
325
|
+
# config/puma.rb
|
|
326
|
+
on_worker_boot do
|
|
327
|
+
$gte_pool = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
|
|
328
|
+
end
|
|
329
|
+
```
|
|
330
|
+
|
|
331
|
+
3. **If using `preload_app!`**, call `GTE.config` in `before_fork` and set
|
|
332
|
+
`GTE_INTRA_OP_NUM_THREADS=1` to avoid thread pool issues in child processes.
|
data/Rakefile
CHANGED
|
@@ -74,15 +74,6 @@ namespace :bench do
|
|
|
74
74
|
)
|
|
75
75
|
end
|
|
76
76
|
|
|
77
|
-
desc 'Sweep execution-provider settings for Puma-like benchmark'
|
|
78
|
-
task :matrix_sweep do
|
|
79
|
-
run_in_nix(
|
|
80
|
-
'bundle', 'exec', 'ruby', 'bench/puma_matrix_sweep.rb',
|
|
81
|
-
'--iterations', '80',
|
|
82
|
-
'--runs', '3'
|
|
83
|
-
)
|
|
84
|
-
end
|
|
85
|
-
|
|
86
77
|
desc 'Run memory probe for single-instance vs duplicate-instance behavior'
|
|
87
78
|
task :memory_probe do
|
|
88
79
|
run_in_nix(
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.
|
|
1
|
+
0.0.16
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -5,8 +5,8 @@ use crate::model_profile::{
|
|
|
5
5
|
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
6
|
};
|
|
7
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
8
|
-
use crate::session::{build_session, SessionPool};
|
|
9
|
-
use crate::tokenizer::{parse_padding_mode_override,
|
|
8
|
+
use crate::session::{build_session, resolve_pool_size, run_session, SessionPool};
|
|
9
|
+
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
10
10
|
use ndarray::Array2;
|
|
11
11
|
use std::path::{Path, PathBuf};
|
|
12
12
|
|
|
@@ -14,22 +14,10 @@ pub struct Embedder {
|
|
|
14
14
|
tokenizer: Tokenizer,
|
|
15
15
|
pool: SessionPool,
|
|
16
16
|
pub config: ModelConfig,
|
|
17
|
+
normalize: bool,
|
|
17
18
|
}
|
|
18
19
|
|
|
19
20
|
impl Embedder {
|
|
20
|
-
pub fn new<P1, P2>(tokenizer_path: P1, model_path: P2, config: ModelConfig) -> Result<Self>
|
|
21
|
-
where
|
|
22
|
-
P1: AsRef<Path>,
|
|
23
|
-
P2: AsRef<Path>,
|
|
24
|
-
{
|
|
25
|
-
let tokenizer =
|
|
26
|
-
Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids, config.padding_mode, None)?;
|
|
27
|
-
let model_path = model_path.as_ref();
|
|
28
|
-
let session = build_session(model_path, &config)?;
|
|
29
|
-
let pool = SessionPool::new(session, model_path, &config)?;
|
|
30
|
-
Ok(Self { tokenizer, pool, config })
|
|
31
|
-
}
|
|
32
|
-
|
|
33
21
|
pub fn from_dir<P: AsRef<Path>>(dir: P, optimization_level: u8, overrides: ModelLoadOverrides<'_>) -> Result<Self> {
|
|
34
22
|
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] =
|
|
35
23
|
["pooler_output", "text_embeds", "sentence_embedding", "last_hidden_state"];
|
|
@@ -52,7 +40,7 @@ impl Embedder {
|
|
|
52
40
|
};
|
|
53
41
|
let padding_mode = parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
54
42
|
|
|
55
|
-
let
|
|
43
|
+
let probe_config = ModelConfig {
|
|
56
44
|
max_length,
|
|
57
45
|
padding_mode,
|
|
58
46
|
output_tensor: String::new(),
|
|
@@ -61,10 +49,8 @@ impl Embedder {
|
|
|
61
49
|
with_attention_mask: true,
|
|
62
50
|
optimization_level,
|
|
63
51
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
64
|
-
lowercase_input: overrides.lowercase_input.unwrap_or(false),
|
|
65
|
-
max_input_chars: overrides.max_input_chars,
|
|
66
52
|
};
|
|
67
|
-
let session = build_session(&model_path, &
|
|
53
|
+
let session = build_session(&model_path, &probe_config)?;
|
|
68
54
|
|
|
69
55
|
validate_supported_text_inputs(&session, "text embedding")?;
|
|
70
56
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
@@ -84,10 +70,10 @@ impl Embedder {
|
|
|
84
70
|
with_attention_mask,
|
|
85
71
|
optimization_level,
|
|
86
72
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
87
|
-
lowercase_input: overrides.lowercase_input.unwrap_or(false),
|
|
88
|
-
max_input_chars: overrides.max_input_chars,
|
|
89
73
|
};
|
|
90
74
|
|
|
75
|
+
let normalize = should_normalize_output(&config.output_tensor);
|
|
76
|
+
|
|
91
77
|
let tokenizer = Tokenizer::new(
|
|
92
78
|
&tokenizer_path,
|
|
93
79
|
config.max_length,
|
|
@@ -96,72 +82,50 @@ impl Embedder {
|
|
|
96
82
|
tokenizer_profile.fixed_padding_length,
|
|
97
83
|
)?;
|
|
98
84
|
|
|
99
|
-
let
|
|
100
|
-
|
|
85
|
+
let pool_size = resolve_pool_size();
|
|
86
|
+
let pool = SessionPool::new(&model_path, &config, pool_size)?;
|
|
87
|
+
Ok(Self { tokenizer, pool, config, normalize })
|
|
101
88
|
}
|
|
102
89
|
|
|
103
90
|
pub fn embed(&self, texts: &[String]) -> Result<Array2<f32>> {
|
|
104
|
-
self.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
let sanitized: Vec<String>;
|
|
109
|
-
let input = if self.config.lowercase_input || self.config.max_input_chars.is_some() {
|
|
110
|
-
sanitized = texts
|
|
111
|
-
.iter()
|
|
112
|
-
.map(|t| {
|
|
113
|
-
let mut s = if self.config.lowercase_input { t.to_lowercase() } else { t.clone() };
|
|
114
|
-
if let Some(max_chars) = self.config.max_input_chars {
|
|
115
|
-
s.truncate(max_chars.min(s.len()));
|
|
116
|
-
}
|
|
117
|
-
s
|
|
118
|
-
})
|
|
119
|
-
.collect();
|
|
120
|
-
&sanitized
|
|
91
|
+
let tokenized = self.tokenizer.tokenize(texts)?;
|
|
92
|
+
let embeddings = self.pool.with_session(|session| run_session(session, &tokenized, &self.config))?;
|
|
93
|
+
if self.normalize {
|
|
94
|
+
Ok(normalize_l2_rows(embeddings))
|
|
121
95
|
} else {
|
|
122
|
-
|
|
123
|
-
}
|
|
124
|
-
let tokenized = self.tokenize(input)?;
|
|
125
|
-
self.run(&tokenized)
|
|
96
|
+
Ok(embeddings)
|
|
97
|
+
}
|
|
126
98
|
}
|
|
127
99
|
|
|
128
|
-
pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
|
|
100
|
+
pub fn tokenize(&self, texts: &[String]) -> Result<crate::tokenizer::Tokenized> {
|
|
129
101
|
self.tokenizer.tokenize(texts)
|
|
130
102
|
}
|
|
131
|
-
|
|
132
|
-
pub fn run(&self, tokenized: &Tokenized) -> Result<Array2<f32>> {
|
|
133
|
-
self.pool.run(tokenized, &self.config)
|
|
134
|
-
}
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
|
|
138
|
-
normalize_l2_rows(embeddings)
|
|
139
103
|
}
|
|
140
104
|
|
|
141
|
-
|
|
105
|
+
fn should_normalize_output(name: &str) -> bool {
|
|
142
106
|
let lower = name.to_ascii_lowercase();
|
|
143
107
|
let base = lower.rsplit('/').next().unwrap_or(&lower);
|
|
144
|
-
base.contains("normalized") || base.contains("l2_norm") || base.contains("l2norm")
|
|
108
|
+
!(base.contains("normalized") || base.contains("l2_norm") || base.contains("l2norm"))
|
|
145
109
|
}
|
|
146
110
|
|
|
147
111
|
#[cfg(test)]
|
|
148
112
|
mod normalize_tests {
|
|
149
|
-
use super::
|
|
113
|
+
use super::should_normalize_output;
|
|
150
114
|
|
|
151
115
|
#[test]
|
|
152
116
|
fn detects_normalized_output_names() {
|
|
153
|
-
assert!(
|
|
154
|
-
assert!(
|
|
155
|
-
assert!(
|
|
156
|
-
assert!(
|
|
117
|
+
assert!(!should_normalize_output("pooled_sentence_embeddings_debiased_normalized"));
|
|
118
|
+
assert!(!should_normalize_output("embeddings/L2_Normalized"));
|
|
119
|
+
assert!(!should_normalize_output("l2norm_output"));
|
|
120
|
+
assert!(!should_normalize_output("norm/l2_norm_tensor"));
|
|
157
121
|
}
|
|
158
122
|
|
|
159
123
|
#[test]
|
|
160
124
|
fn does_not_detect_raw_output_names() {
|
|
161
|
-
assert!(
|
|
162
|
-
assert!(
|
|
163
|
-
assert!(
|
|
164
|
-
assert!(
|
|
165
|
-
assert!(
|
|
125
|
+
assert!(should_normalize_output("last_hidden_state"));
|
|
126
|
+
assert!(should_normalize_output("text_embeds"));
|
|
127
|
+
assert!(should_normalize_output("pooler_output"));
|
|
128
|
+
assert!(should_normalize_output("sentence_embedding"));
|
|
129
|
+
assert!(should_normalize_output("logits"));
|
|
166
130
|
}
|
|
167
131
|
}
|
data/ext/gte/src/lib.rs
CHANGED
|
@@ -18,6 +18,7 @@ use magnus::{prelude::*, Error, Ruby};
|
|
|
18
18
|
#[magnus::init]
|
|
19
19
|
fn init(ruby: &Ruby) -> Result<(), Error> {
|
|
20
20
|
let module = ruby.define_module("GTE")?;
|
|
21
|
+
#[allow(unused_results)]
|
|
21
22
|
module.define_error("Error", ruby.exception_standard_error())?;
|
|
22
23
|
ruby_embedder::register(ruby)?;
|
|
23
24
|
std::panic::set_hook(Box::new(|info| {
|
data/ext/gte/src/model_config.rs
CHANGED
|
@@ -23,8 +23,6 @@ pub struct ModelConfig {
|
|
|
23
23
|
pub with_attention_mask: bool,
|
|
24
24
|
pub optimization_level: u8,
|
|
25
25
|
pub execution_providers: Option<String>,
|
|
26
|
-
pub lowercase_input: bool,
|
|
27
|
-
pub max_input_chars: Option<usize>,
|
|
28
26
|
}
|
|
29
27
|
|
|
30
28
|
#[derive(Debug, Clone, Copy, Default)]
|
|
@@ -34,6 +32,4 @@ pub struct ModelLoadOverrides<'a> {
|
|
|
34
32
|
pub max_length: Option<usize>,
|
|
35
33
|
pub padding: Option<&'a str>,
|
|
36
34
|
pub execution_providers: Option<&'a str>,
|
|
37
|
-
pub lowercase_input: Option<bool>,
|
|
38
|
-
pub max_input_chars: Option<usize>,
|
|
39
35
|
}
|
data/ext/gte/src/pipeline.rs
CHANGED
|
@@ -11,21 +11,20 @@ pub struct InputTensors<'a> {
|
|
|
11
11
|
|
|
12
12
|
impl<'a> InputTensors<'a> {
|
|
13
13
|
pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
|
|
14
|
-
let input_ids_view
|
|
15
|
-
|
|
16
|
-
let attention_mask: ArrayView2<'_, i64> =
|
|
17
|
-
ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
|
|
14
|
+
let input_ids_view = tokenized.input_ids.view();
|
|
15
|
+
let attention_mask = tokenized.attn_masks.view();
|
|
18
16
|
|
|
19
|
-
let mut inputs = Vec::with_capacity(2
|
|
20
|
-
inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
|
|
17
|
+
let mut inputs = Vec::with_capacity(2);
|
|
21
18
|
|
|
22
19
|
if with_attention_mask {
|
|
20
|
+
inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
|
|
23
21
|
inputs.push(("attention_mask", SessionInputValue::from(TensorRef::from_array_view(attention_mask)?)));
|
|
22
|
+
} else {
|
|
23
|
+
inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
|
|
24
24
|
}
|
|
25
25
|
|
|
26
|
-
if let Some(type_ids) = tokenized.type_ids
|
|
27
|
-
let type_ids_view
|
|
28
|
-
ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
|
|
26
|
+
if let Some(ref type_ids) = tokenized.type_ids {
|
|
27
|
+
let type_ids_view = type_ids.view();
|
|
29
28
|
inputs.push(("token_type_ids", SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?)));
|
|
30
29
|
}
|
|
31
30
|
|
data/ext/gte/src/postprocess.rs
CHANGED
|
@@ -94,12 +94,13 @@ fn mean_pool_contiguous(
|
|
|
94
94
|
if mask_row.iter().all(|&weight| weight == 1) {
|
|
95
95
|
for token_index in 0..seq {
|
|
96
96
|
let token_base = hidden_base + token_index * dim;
|
|
97
|
-
|
|
98
|
-
|
|
97
|
+
let token_slice = &hidden[token_base..token_base + dim];
|
|
98
|
+
for (out, &h) in output_row.iter_mut().zip(token_slice.iter()) {
|
|
99
|
+
*out += h;
|
|
99
100
|
}
|
|
100
101
|
}
|
|
101
102
|
|
|
102
|
-
for value in output_row {
|
|
103
|
+
for value in output_row.iter_mut() {
|
|
103
104
|
*value *= seq_inverse;
|
|
104
105
|
}
|
|
105
106
|
continue;
|
|
@@ -114,15 +115,16 @@ fn mean_pool_contiguous(
|
|
|
114
115
|
|
|
115
116
|
let weight = weight_raw as f32;
|
|
116
117
|
let token_base = hidden_base + token_index * dim;
|
|
117
|
-
|
|
118
|
-
|
|
118
|
+
let token_slice = &hidden[token_base..token_base + dim];
|
|
119
|
+
for (out, &h) in output_row.iter_mut().zip(token_slice.iter()) {
|
|
120
|
+
*out += h * weight;
|
|
119
121
|
}
|
|
120
122
|
weight_sum += weight;
|
|
121
123
|
}
|
|
122
124
|
|
|
123
125
|
if weight_sum > 0.0 {
|
|
124
126
|
let inverse = weight_sum.recip();
|
|
125
|
-
for value in output_row {
|
|
127
|
+
for value in output_row.iter_mut() {
|
|
126
128
|
*value *= inverse;
|
|
127
129
|
}
|
|
128
130
|
}
|