gte 0.0.14-x86_64-linux → 0.0.16-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 89f684dfc37cc272603c6ba0cccadff558d743eb1a3d2b25ba5bc6eafe1efbeb
4
- data.tar.gz: d819b6b3523bf19ff84f0b7f5203213b470ed7e83d6ed3dcf2713daaf42b9a32
3
+ metadata.gz: 20d9efdde7d7af021cc2940f5f1a2c6139f05530aa9504dbd94f6d3470fd0ff2
4
+ data.tar.gz: dd999664494476009790aa7c1431abf900eefae2d846818e22bda8735712ee7d
5
5
  SHA512:
6
- metadata.gz: 153a2b6d1d8bdff7414ffa768f1ab0084f6a49578fa63e933a07f51eda21d8c4d10ae2be11ce13af314655488e2df060dad829446b19020a7b7d0d153e72ddc7
7
- data.tar.gz: 899ad238106610f95cb76abf3e48ae5625cc3098a6477f168db8e29c29d22f1af2683acf4fa105e1a66ce7c9b0465564f836319aab9da1d966aaaf67e2a994e6
6
+ metadata.gz: 60a197aed55cde07447227d011c95dd835ee150bd0e2d16319d434367da9dd5f5ef54ccfe09e509547ad13ec92c0919360125e743a174910e8d1b309969889f8
7
+ data.tar.gz: c1e5f83a48ebfb0e8fcb43d02bf2c744daa959386c5499cd489d1e220fbdf9b7c2a5baf2a48c51f3495db648a5b20c5f7888647dd78d497d8b2abcece2bcf890
data/Gemfile CHANGED
@@ -8,7 +8,6 @@ gem 'rake'
8
8
  gem 'rake-compiler'
9
9
  gem 'rb_sys'
10
10
  gem 'rspec'
11
- gem 'rspec-benchmark'
12
11
  gem 'rubocop', require: false
13
12
 
14
13
  group :bench do
data/README.md CHANGED
@@ -15,32 +15,29 @@ model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
15
15
  tensor = model.embed("query: hello world")
16
16
  vector = tensor.row(0)
17
17
 
18
- # [] with string => Array<Float> (single vector)
19
- single = model["query: nearest coffee shop"]
20
-
21
- # [] with array => GTE::Tensor (batch)
22
- batch = model[["query: hello", "query: world"]]
18
+ # Binary f32 bytes (zero-copy to Numo/NumPy)
19
+ bytes = model.embed_binary("query: hello world")
23
20
  ```
24
21
 
25
- ## Embedding Config (`GTE.config`)
22
+ ## Embedding Config (`GTE::Pool`)
26
23
 
27
- `GTE.config(model_dir)` builds (and caches) a `GTE::Model`.
24
+ `GTE.config(model_dir)` creates a new session pool; it holds one ONNX session unless `pool_size:` or the environment (see Session Pool below) requests more.
28
25
 
29
26
  ```ruby
30
- default_model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
31
-
32
- raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
33
- config.with(normalize: false)
34
- end
27
+ default = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
28
+ default.embed("query: hello world")
35
29
 
36
- custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
30
+ # With config overrides
31
+ configurable = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
32
  config.with(
38
33
  output_tensor: "last_hidden_state",
39
- max_length: 256,
40
- padding: "batch_longest",
41
- optimization_level: 3
34
+ max_length: 128,
35
+ execution_providers: "xnnpack"
42
36
  )
43
37
  end
38
+
39
+ # Explicit pool size (each session costs ~120MB RSS)
40
+ large = GTE.config(ENV.fetch("GTE_MODEL_DIR"), pool_size: 4)
44
41
  ```
45
42
 
46
43
  Config fields and defaults:
@@ -48,19 +45,11 @@ Config fields and defaults:
48
45
  - `model_dir`: absolute path to model directory
49
46
  - `optimization_level`: `3`
50
47
  - `model_name`: `nil`
51
- - `normalize`: `true` (L2 normalization at Ruby-facing API)
52
48
  - `output_tensor`: `nil` (auto-select output tensor)
53
49
  - `max_length`: `nil` (uses tokenizer/model defaults)
54
50
  - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
55
51
  - `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
56
52
 
57
- Notes:
58
-
59
- - Return a `Config::Text` from the block (for example, `config.with(...)`).
60
- - Model instances are cached by full config key; different config values create different cached instances.
61
- - `GTE.warmup(model, threads:)` pre-warms thread-local ONNX sessions eagerly at boot.
62
- Useful in multi-threaded servers (Puma, Sidekiq) to avoid ~100-500ms cold-start latency.
63
-
64
53
  Common model presets:
65
54
 
66
55
  ```ruby
@@ -91,27 +80,28 @@ clip = GTE.config(ENV.fetch("GTE_CLIP_DIR")) do |config|
91
80
  end
92
81
  ```
93
82
 
94
- Picking a specific layer:
83
+ Output selection:
95
84
 
96
85
  - Use `output_tensor:` to request a named model output.
97
86
  - `last_hidden_state` gives token-level hidden states and is mean-pooled by `gte` when the tensor is rank 3.
98
- - `pooler_output`, `sentence_embedding`, and similar 2D tensors are returned directly and then L2-normalized by default.
87
+ - `pooler_output`, `sentence_embedding`, and similar 2D tensors are returned directly and L2-normalized.
88
+ - If the output tensor name suggests already-normalized output (e.g. `l2_norm`, `normalized`), normalization is skipped.
99
89
  - If the requested tensor is not present in the model, `gte` raises an error instead of silently falling back.
100
90
 
101
- Low-level embedder setup (without model cache):
91
+ Low-level embedder setup (without Pool convenience):
102
92
 
103
93
  ```ruby
104
- embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
105
- config.with(execution_providers: "cpu")
106
- end
94
+ embedder = GTE::Embedder.from_config(
95
+ GTE::Embedder.default_config(ENV.fetch("GTE_MODEL_DIR"))
96
+ )
107
97
  ```
108
98
 
109
99
  ## Reranker
110
100
 
111
- Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
101
+ Use `GTE::Reranker.new(model_dir)` for cross-encoder reranking.
112
102
 
113
103
  ```ruby
114
- reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
104
+ reranker = GTE::Reranker.new(ENV.fetch("GTE_RERANK_DIR")) do |config|
115
105
  config.with(sigmoid: true)
116
106
  end
117
107
 
@@ -124,13 +114,6 @@ candidates = [
124
114
  # Raw scores aligned with input order
125
115
  scores = reranker.score(query, candidates)
126
116
  # => [0.93, 0.07]
127
-
128
- # Ranked output sorted by score desc
129
- ranked = reranker.rerank(query: query, candidates: candidates)
130
- # => [
131
- # { index: 0, score: 0.93, text: "Backpropagation and gradient descent are core techniques." },
132
- # { index: 1, score: 0.07, text: "This recipe uses flour and eggs." }
133
- # ]
134
117
  ```
135
118
 
136
119
  Reranker config fields and defaults:
@@ -144,26 +127,10 @@ Reranker config fields and defaults:
144
127
  - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
145
128
  - `execution_providers`: `nil`
146
129
 
147
- Session pool sizing:
148
-
149
- - `GTE_SESSION_POOL_CAP`: optional positive integer cap for internal ONNX session pool size.
150
- - Unset by default; runtime uses available CPU parallelism.
151
-
152
130
  ## Automatic Tuning
153
131
 
154
132
  `gte` automatically adapts to the hardware — no configuration required.
155
133
 
156
- ### ONNX Intra-op Threads
157
-
158
- - Auto-detected via `std::thread::available_parallelism()` capped at 4.
159
- - Prevents oversubscription on high-concurrency workloads.
160
- - Override with `GTE_INTRA_OP_NUM_THREADS` env var.
161
-
162
- ### ONNX Inter-op Threads
163
-
164
- - Defaults to 1 (text embedding graphs are linear chains with no independent parallel nodes).
165
- - Override with `GTE_INTER_OP_NUM_THREADS` env var.
166
-
167
134
  ### Execution Providers
168
135
 
169
136
  `gte` automatically tries XNNPACK for optimized CPU inference. Falls back to
@@ -183,28 +150,61 @@ export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
183
150
 
184
151
  Set `cpu` or `none` to skip auto-detect and use ORT's default CPU provider.
185
152
 
153
+ ### Session Pool
154
+
155
+ gte uses a **pre-allocated session pool** per worker — it creates N sessions at
156
+ construction time, where N is determined by:
157
+
158
+ | Priority | Source | Description |
159
+ |----------|--------|-------------|
160
+ | 1 | `GTE_SESSION_POOL_SIZE` | Explicit size (e.g. `4`) |
161
+ | 2 | `PUMA_MAX_THREADS` | Match Puma concurrency (capped at 8) |
162
+ | 3 | Default | `1` (single session, matching the unsplash-api singleton pattern) |
163
+
164
+ The pool is fixed-size: sessions are never created or destroyed after construction.
165
+ When all sessions are busy, the calling thread blocks on `parking_lot::Mutex`
166
+ until a session is released. This avoids the allocation and memory overhead of
167
+ lazy-growing pools while matching the concurrency needs of application threads.
168
+
186
169
  ### Session Pre-Warming
187
170
 
188
- ONNX sessions are created lazily per OS thread. In multi-threaded servers (Puma, Sidekiq),
189
- each thread creates its own session on first use (~100-500ms cold start).
190
- Pre-warm sessions eagerly at boot:
171
+ The pool is pre-warmed automatically in `GTE.config`: one inference per
172
+ session is run on construction so the first production request never hits a cold
173
+ cache. No manual warmup step needed.
191
174
 
192
- ```ruby
193
- model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
175
+ To re-warm (useful after fork in Puma's `on_worker_boot`):
194
176
 
195
- # Pre-warm thread-local sessions for a Puma server with 5 threads:
196
- GTE.warmup(model, threads: 5)
177
+ ```ruby
178
+ pool.warmup
197
179
  ```
198
180
 
199
- ## Runtime + Result Examples
181
+ ### Tuning Performance
182
+
183
+ | Variable | Effect | Default |
184
+ |----------|--------|---------|
185
+ | `GTE_SESSION_POOL_SIZE` | Max ONNX sessions per worker | `1` (or `PUMA_MAX_THREADS`) |
186
+ | `GTE_INTRA_OP_NUM_THREADS` | Threads ONNX Runtime uses per inference op | `min(CPU cores, 4)` |
187
+ | `GTE_INTER_OP_NUM_THREADS` | Threads for independent graph nodes (irrelevant for text models) | `1` |
188
+ | `GTE_EXECUTION_PROVIDERS` | Comma-separated: `xnnpack`, `coreml`, `cpu` | Auto: `xnnpack` on arm64 |
189
+
190
+ **To squeeze more throughput:**
191
+ - Set `GTE_SESSION_POOL_SIZE` to match or slightly exceed your Puma max thread count (`PUMA_MAX_THREADS`).
192
+ - On machines with many cores, reduce `GTE_INTRA_OP_NUM_THREADS` to `1` or `2`
193
+ to avoid CPU oversubscription when multiple sessions run concurrently.
194
+
195
+ **Memory estimation per worker:**
196
+ - Pool size N (default 1): **N × model file size × 3–5**
197
+ - Each additional session adds ~120MB RSS on arm64 with XNNPACK.
198
+
199
+ ## Runtime
200
200
 
201
201
  Process-local reuse (recommended for Puma/web servers):
202
202
 
203
203
  ```ruby
204
- EMBEDDER = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
204
+ $gte = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
205
205
 
206
206
  def embed_query(text)
207
- EMBEDDER[text] # Array<Float>
207
+ $gte.embed(text).row(0) # Array<Float>
208
208
  end
209
209
  ```
210
210
 
@@ -219,7 +219,6 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
219
219
 
220
220
  Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
221
221
 
222
-
223
222
  ## Development
224
223
 
225
224
  Run commands inside `nix develop` via Make targets:
@@ -256,38 +255,37 @@ make bench-docker-validate # cross-validation checks
256
255
 
257
256
  | Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
258
257
  |------------|---------|---------------|-------|---------|---------------|
259
- | c=1 | ~12ms | ~120ms | 9-10× | ~95 | ~10 |
260
- | c=4 | ~39ms | ~503ms | 10-13× | ~228 | ~10 |
261
- | c=8 | ~146ms | ~613ms | 3-4× | ~224 | ~10 |
262
- | c=16 | ~430ms | ~611ms | 1-1.5× | ~226 | ~11 |
258
+ | c=1 | ~14ms | ~92ms | 6.4× | ~89 | ~21 |
259
+ | c=2 | ~15ms | ~175ms | 11.4× | ~163 | ~21 |
260
+ | c=4 | ~39ms | ~293ms | 7.4× | ~219 | ~24 |
261
+ | c=8 | ~75ms | ~502ms | 6.7× | ~195 | ~24 |
262
+ | c=16 | ~279ms | ~606ms | 2.2× | ~219 | ~26 |
263
263
 
264
264
  #### E5 (384-dim, last_hidden_state + mean pool)
265
265
 
266
266
  | Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
267
267
  |------------|---------|---------------|-------|---------|---------------|
268
- | c=1 | ~7ms | ~120ms | 16-17× | ~160 | ~10 |
269
- | c=4 | ~12ms | ~430ms | 35-40× | ~477 | ~10 |
270
- | c=8 | ~64ms | ~530ms | 8-9× | ~503 | ~10 |
271
- | c=16 | ~205ms | ~534ms | 2-3× | ~509 | ~11 |
268
+ | c=1 | ~8ms | ~73ms | 9.3× | ~152 | ~32 |
269
+ | c=2 | ~8ms | ~95ms | 11.8× | ~291 | ~36 |
270
+ | c=4 | ~22ms | ~163ms | 7.5× | ~432 | ~45 |
271
+ | c=8 | ~51ms | ~291ms | 5.7× | ~451 | ~43 |
272
+ | c=16 | ~133ms | ~1080ms | 8.1× | ~467 | ~47 |
272
273
 
273
- GTE releases the GVL during ONNX inference, enabling true parallelism across Puma threads.
274
- Pure Ruby is GVL-bound (~10 RPS regardless of concurrency).
274
+ GTE releases the GVL during ONNX inference, enabling true parallelism across
275
+ Puma threads and worker processes. Pure Ruby inference is serialized by the GVL
276
+ within each worker (~20–50 RPS, roughly flat across concurrency levels).
275
277
 
276
- The Puma thread pool (min=2, max=5) limits throughput at c=16+.
277
- GTE's pipelining and GVL release already saturate the available threads at c=4.
278
+ Config: Puma workers=2, threads min=2/max=5, cpus=4, mem_limit=3g.
279
+ Load generated by wrk in Docker over a random 135-text query set, 15 s per run.
278
280
 
279
281
  ### In-Process Benchmarks
280
282
 
281
283
  ```bash
282
284
  make bench
283
- nix develop -c bundle exec rake bench:pure_compare
284
- nix develop -c bundle exec rake bench:matrix_sweep
285
285
  nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
286
286
  ```
287
287
 
288
288
  - `make bench`: Puma-like single-request comparison at concurrency `16`
289
- - `rake bench:pure_compare`: batch amortization comparison
290
- - `rake bench:matrix_sweep`: GTE provider sweep using the shared result schema
291
289
  - Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
292
290
 
293
291
  To run benchmark + append a `RUNS.md` entry + enforce goal checks:
@@ -300,3 +298,35 @@ make bench-record
300
298
 
301
299
  - Enforces the goal metric (`response_time_p95`) across every enabled competitor.
302
300
  - Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
301
+
302
+ ## Fork Safety
303
+
304
+ GTE uses ONNX Runtime sessions which maintain internal thread pools for parallelism
305
+ (`GTE_INTRA_OP_NUM_THREADS`, default `min(cpus, 4)`). These thread pools are
306
+ per-session and may not survive `fork()` on some platforms.
307
+
308
+ **With Puma's `preload_app!`:**
309
+
310
+ Sessions built before `fork()` share memory via COW, but the internal ORT threads
311
+ created during `Session::builder().commit_from_file()` do not exist in the child
312
+ process. When a forked worker calls `session.run()`, ORT must recreate these
313
+ threads, which adds latency to the first inference call.
314
+
315
+ **Recommendations:**
316
+
317
+ 1. **Set `GTE_INTRA_OP_NUM_THREADS=1`** in forked environments to avoid creating
318
+ per-session thread pools entirely. ORT will run inference single-threaded,
319
+ which is acceptable when multiple sessions handle concurrency.
320
+ 2. **Build sessions in `on_worker_boot`** instead of before fork to guarantee
321
+ fresh thread pools in each worker. This adds ~200ms to worker startup per
322
+ model but ensures consistent inference latency:
323
+
324
+ ```ruby
325
+ # config/puma.rb
326
+ on_worker_boot do
327
+ $gte_pool = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
328
+ end
329
+ ```
330
+
331
+ 3. **If using `preload_app!`**, call `GTE.config` in `before_fork` and set
332
+ `GTE_INTRA_OP_NUM_THREADS=1` to avoid thread pool issues in child processes.
data/Rakefile CHANGED
@@ -74,15 +74,6 @@ namespace :bench do
74
74
  )
75
75
  end
76
76
 
77
- desc 'Sweep execution-provider settings for Puma-like benchmark'
78
- task :matrix_sweep do
79
- run_in_nix(
80
- 'bundle', 'exec', 'ruby', 'bench/puma_matrix_sweep.rb',
81
- '--iterations', '80',
82
- '--runs', '3'
83
- )
84
- end
85
-
86
77
  desc 'Run memory probe for single-instance vs duplicate-instance behavior'
87
78
  task :memory_probe do
88
79
  run_in_nix(
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.14
1
+ 0.0.16
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.14"
3
+ version = "0.0.16"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -22,6 +22,7 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
22
22
  rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
23
23
  magnus = { version = "0.8", optional = true }
24
24
  ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
25
+ parking_lot = "0.12"
25
26
  tokenizers = "0.21.0"
26
27
  ndarray = "0.17"
27
28
  serde_json = "1"
@@ -5,8 +5,8 @@ use crate::model_profile::{
5
5
  resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
6
  };
7
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
8
- use crate::session::{build_session, SessionPool};
9
- use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
8
+ use crate::session::{build_session, resolve_pool_size, run_session, SessionPool};
9
+ use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
10
10
  use ndarray::Array2;
11
11
  use std::path::{Path, PathBuf};
12
12
 
@@ -14,22 +14,10 @@ pub struct Embedder {
14
14
  tokenizer: Tokenizer,
15
15
  pool: SessionPool,
16
16
  pub config: ModelConfig,
17
+ normalize: bool,
17
18
  }
18
19
 
19
20
  impl Embedder {
20
- pub fn new<P1, P2>(tokenizer_path: P1, model_path: P2, config: ModelConfig) -> Result<Self>
21
- where
22
- P1: AsRef<Path>,
23
- P2: AsRef<Path>,
24
- {
25
- let tokenizer =
26
- Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids, config.padding_mode, None)?;
27
- let model_path = model_path.as_ref();
28
- let session = build_session(model_path, &config)?;
29
- let pool = SessionPool::new(session, model_path, &config)?;
30
- Ok(Self { tokenizer, pool, config })
31
- }
32
-
33
21
  pub fn from_dir<P: AsRef<Path>>(dir: P, optimization_level: u8, overrides: ModelLoadOverrides<'_>) -> Result<Self> {
34
22
  const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] =
35
23
  ["pooler_output", "text_embeds", "sentence_embedding", "last_hidden_state"];
@@ -52,7 +40,7 @@ impl Embedder {
52
40
  };
53
41
  let padding_mode = parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
54
42
 
55
- let session_config = ModelConfig {
43
+ let probe_config = ModelConfig {
56
44
  max_length,
57
45
  padding_mode,
58
46
  output_tensor: String::new(),
@@ -61,10 +49,8 @@ impl Embedder {
61
49
  with_attention_mask: true,
62
50
  optimization_level,
63
51
  execution_providers: overrides.execution_providers.map(str::to_string),
64
- lowercase_input: overrides.lowercase_input.unwrap_or(false),
65
- max_input_chars: overrides.max_input_chars,
66
52
  };
67
- let session = build_session(&model_path, &session_config)?;
53
+ let session = build_session(&model_path, &probe_config)?;
68
54
 
69
55
  validate_supported_text_inputs(&session, "text embedding")?;
70
56
  let with_type_ids = has_input(&session, "token_type_ids");
@@ -84,10 +70,10 @@ impl Embedder {
84
70
  with_attention_mask,
85
71
  optimization_level,
86
72
  execution_providers: overrides.execution_providers.map(str::to_string),
87
- lowercase_input: overrides.lowercase_input.unwrap_or(false),
88
- max_input_chars: overrides.max_input_chars,
89
73
  };
90
74
 
75
+ let normalize = should_normalize_output(&config.output_tensor);
76
+
91
77
  let tokenizer = Tokenizer::new(
92
78
  &tokenizer_path,
93
79
  config.max_length,
@@ -96,72 +82,50 @@ impl Embedder {
96
82
  tokenizer_profile.fixed_padding_length,
97
83
  )?;
98
84
 
99
- let pool = SessionPool::new(session, &model_path, &session_config)?;
100
- Ok(Self { tokenizer, pool, config })
85
+ let pool_size = resolve_pool_size();
86
+ let pool = SessionPool::new(&model_path, &config, pool_size)?;
87
+ Ok(Self { tokenizer, pool, config, normalize })
101
88
  }
102
89
 
103
90
  pub fn embed(&self, texts: &[String]) -> Result<Array2<f32>> {
104
- self.embed_ref(texts)
105
- }
106
-
107
- pub fn embed_ref(&self, texts: &[String]) -> Result<Array2<f32>> {
108
- let sanitized: Vec<String>;
109
- let input = if self.config.lowercase_input || self.config.max_input_chars.is_some() {
110
- sanitized = texts
111
- .iter()
112
- .map(|t| {
113
- let mut s = if self.config.lowercase_input { t.to_lowercase() } else { t.clone() };
114
- if let Some(max_chars) = self.config.max_input_chars {
115
- s.truncate(max_chars.min(s.len()));
116
- }
117
- s
118
- })
119
- .collect();
120
- &sanitized
91
+ let tokenized = self.tokenizer.tokenize(texts)?;
92
+ let embeddings = self.pool.with_session(|session| run_session(session, &tokenized, &self.config))?;
93
+ if self.normalize {
94
+ Ok(normalize_l2_rows(embeddings))
121
95
  } else {
122
- texts
123
- };
124
- let tokenized = self.tokenize(input)?;
125
- self.run(&tokenized)
96
+ Ok(embeddings)
97
+ }
126
98
  }
127
99
 
128
- pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
100
+ pub fn tokenize(&self, texts: &[String]) -> Result<crate::tokenizer::Tokenized> {
129
101
  self.tokenizer.tokenize(texts)
130
102
  }
131
-
132
- pub fn run(&self, tokenized: &Tokenized) -> Result<Array2<f32>> {
133
- self.pool.run(tokenized, &self.config)
134
- }
135
- }
136
-
137
- pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
138
- normalize_l2_rows(embeddings)
139
103
  }
140
104
 
141
- pub fn output_name_suggests_normalized(name: &str) -> bool {
105
+ fn should_normalize_output(name: &str) -> bool {
142
106
  let lower = name.to_ascii_lowercase();
143
107
  let base = lower.rsplit('/').next().unwrap_or(&lower);
144
- base.contains("normalized") || base.contains("l2_norm") || base.contains("l2norm")
108
+ !(base.contains("normalized") || base.contains("l2_norm") || base.contains("l2norm"))
145
109
  }
146
110
 
147
111
  #[cfg(test)]
148
112
  mod normalize_tests {
149
- use super::output_name_suggests_normalized;
113
+ use super::should_normalize_output;
150
114
 
151
115
  #[test]
152
116
  fn detects_normalized_output_names() {
153
- assert!(output_name_suggests_normalized("pooled_sentence_embeddings_debiased_normalized"));
154
- assert!(output_name_suggests_normalized("embeddings/L2_Normalized"));
155
- assert!(output_name_suggests_normalized("l2norm_output"));
156
- assert!(output_name_suggests_normalized("norm/l2_norm_tensor"));
117
+ assert!(!should_normalize_output("pooled_sentence_embeddings_debiased_normalized"));
118
+ assert!(!should_normalize_output("embeddings/L2_Normalized"));
119
+ assert!(!should_normalize_output("l2norm_output"));
120
+ assert!(!should_normalize_output("norm/l2_norm_tensor"));
157
121
  }
158
122
 
159
123
  #[test]
160
124
  fn does_not_detect_raw_output_names() {
161
- assert!(!output_name_suggests_normalized("last_hidden_state"));
162
- assert!(!output_name_suggests_normalized("text_embeds"));
163
- assert!(!output_name_suggests_normalized("pooler_output"));
164
- assert!(!output_name_suggests_normalized("sentence_embedding"));
165
- assert!(!output_name_suggests_normalized("logits"));
125
+ assert!(should_normalize_output("last_hidden_state"));
126
+ assert!(should_normalize_output("text_embeds"));
127
+ assert!(should_normalize_output("pooler_output"));
128
+ assert!(should_normalize_output("sentence_embedding"));
129
+ assert!(should_normalize_output("logits"));
166
130
  }
167
131
  }
data/ext/gte/src/lib.rs CHANGED
@@ -18,6 +18,7 @@ use magnus::{prelude::*, Error, Ruby};
18
18
  #[magnus::init]
19
19
  fn init(ruby: &Ruby) -> Result<(), Error> {
20
20
  let module = ruby.define_module("GTE")?;
21
+ #[allow(unused_results)]
21
22
  module.define_error("Error", ruby.exception_standard_error())?;
22
23
  ruby_embedder::register(ruby)?;
23
24
  std::panic::set_hook(Box::new(|info| {
@@ -23,8 +23,6 @@ pub struct ModelConfig {
23
23
  pub with_attention_mask: bool,
24
24
  pub optimization_level: u8,
25
25
  pub execution_providers: Option<String>,
26
- pub lowercase_input: bool,
27
- pub max_input_chars: Option<usize>,
28
26
  }
29
27
 
30
28
  #[derive(Debug, Clone, Copy, Default)]
@@ -34,6 +32,4 @@ pub struct ModelLoadOverrides<'a> {
34
32
  pub max_length: Option<usize>,
35
33
  pub padding: Option<&'a str>,
36
34
  pub execution_providers: Option<&'a str>,
37
- pub lowercase_input: Option<bool>,
38
- pub max_input_chars: Option<usize>,
39
35
  }
@@ -11,21 +11,20 @@ pub struct InputTensors<'a> {
11
11
 
12
12
  impl<'a> InputTensors<'a> {
13
13
  pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
14
- let input_ids_view: ArrayView2<'_, i64> =
15
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.input_ids.as_slice())?;
16
- let attention_mask: ArrayView2<'_, i64> =
17
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), tokenized.attn_masks.as_slice())?;
14
+ let input_ids_view = tokenized.input_ids.view();
15
+ let attention_mask = tokenized.attn_masks.view();
18
16
 
19
- let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
20
- inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
17
+ let mut inputs = Vec::with_capacity(2);
21
18
 
22
19
  if with_attention_mask {
20
+ inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
23
21
  inputs.push(("attention_mask", SessionInputValue::from(TensorRef::from_array_view(attention_mask)?)));
22
+ } else {
23
+ inputs.push(("input_ids", SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?)));
24
24
  }
25
25
 
26
- if let Some(type_ids) = tokenized.type_ids.as_deref() {
27
- let type_ids_view: ArrayView2<'_, i64> =
28
- ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
26
+ if let Some(ref type_ids) = tokenized.type_ids {
27
+ let type_ids_view = type_ids.view();
29
28
  inputs.push(("token_type_ids", SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?)));
30
29
  }
31
30
 
@@ -94,12 +94,13 @@ fn mean_pool_contiguous(
94
94
  if mask_row.iter().all(|&weight| weight == 1) {
95
95
  for token_index in 0..seq {
96
96
  let token_base = hidden_base + token_index * dim;
97
- for dim_index in 0..dim {
98
- output_row[dim_index] += hidden[token_base + dim_index];
97
+ let token_slice = &hidden[token_base..token_base + dim];
98
+ for (out, &h) in output_row.iter_mut().zip(token_slice.iter()) {
99
+ *out += h;
99
100
  }
100
101
  }
101
102
 
102
- for value in output_row {
103
+ for value in output_row.iter_mut() {
103
104
  *value *= seq_inverse;
104
105
  }
105
106
  continue;
@@ -114,15 +115,16 @@ fn mean_pool_contiguous(
114
115
 
115
116
  let weight = weight_raw as f32;
116
117
  let token_base = hidden_base + token_index * dim;
117
- for dim_index in 0..dim {
118
- output_row[dim_index] += hidden[token_base + dim_index] * weight;
118
+ let token_slice = &hidden[token_base..token_base + dim];
119
+ for (out, &h) in output_row.iter_mut().zip(token_slice.iter()) {
120
+ *out += h * weight;
119
121
  }
120
122
  weight_sum += weight;
121
123
  }
122
124
 
123
125
  if weight_sum > 0.0 {
124
126
  let inverse = weight_sum.recip();
125
- for value in output_row {
127
+ for value in output_row.iter_mut() {
126
128
  *value *= inverse;
127
129
  }
128
130
  }