gte 0.0.13 → 0.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 278028df09fbcdd14fd583f0af5e1a8c9553adb28fe7aa0bc67b67666dbbdccd
4
- data.tar.gz: ce994e3f505200ed4654ca8f87f585ff88919201fe82dd79007622f07a3d1ea0
3
+ metadata.gz: 37ad2ef3f640b8bbaefa14cae29541f329c3b99950a2867b9054b8e8854ca242
4
+ data.tar.gz: 0b54c757ca510ccc8644e0d3c1519aa7b6576fa1da4ba9012535e0b0b7d598dc
5
5
  SHA512:
6
- metadata.gz: 742f1830ff2b83f89726be527c4323a81649b04f341b7adc0544a9000373f6a097c0b4b4ba211ead5912ba45d876565fbaab6d723ef8f06c488ab7827323f827
7
- data.tar.gz: 75e91b3d4c3980b166268c6468b96bebe4b74db999e0cee433a295e57d89bec95c7614b004c61e8b3ed88cff30f02f3b6aff74de710d3dd3bb34552f36fb3422
6
+ metadata.gz: 5d67fc8c73aa2b162bf804082accc849ce972b4df51e2ec86bb7d977cd760e2bf5b760d7de141b5fbd5b6f68dd0e51d2bd15c84d1cb8dc9a8f4bf424b62490ba
7
+ data.tar.gz: c6cefec4d42f7ca72980e4d3454447aa681f05f0b72f96ceec226b5efd5ae32bade13d85ae55f032098d7cb7f72b530d8332f1eaf25fa364d0c33703da13e043
data/README.md CHANGED
@@ -58,6 +58,8 @@ Notes:
58
58
 
59
59
  - Return a `Config::Text` from the block (for example, `config.with(...)`).
60
60
  - Model instances are cached by full config key; different config values create different cached instances.
61
+ - `GTE.warmup(model, threads:)` pre-warms thread-local ONNX sessions eagerly at boot.
62
+ Useful in multi-threaded servers (Puma, Sidekiq) to avoid a ~100-500ms cold start on each thread's first use.
61
63
 
62
64
  Common model presets:
63
65
 
@@ -73,7 +75,7 @@ end
73
75
 
74
76
  siglip2 = GTE.config(ENV.fetch("GTE_SIGLIP2_DIR")) do |config|
75
77
  config.with(
76
- model_name: "text_model_int8.onnx",
78
+ model_name: "text_model.onnx",
77
79
  output_tensor: "pooler_output",
78
80
  max_length: 64,
79
81
  execution_providers: "cpu"
@@ -147,6 +149,53 @@ Session pool sizing:
147
149
  - `GTE_SESSION_POOL_CAP`: optional positive integer cap for internal ONNX session pool size.
148
150
  - Unset by default; runtime uses available CPU parallelism.
149
151
 
152
+ ## Automatic Tuning
153
+
154
+ `gte` automatically adapts to the hardware — no configuration required.
155
+
156
+ ### ONNX Intra-op Threads
157
+
158
+ - Auto-detected via `std::thread::available_parallelism()` capped at 4.
159
+ - Prevents oversubscription on high-concurrency workloads.
160
+ - Override with `GTE_INTRA_OP_NUM_THREADS` env var.
161
+
162
+ ### ONNX Inter-op Threads
163
+
164
+ - Defaults to 1 (text embedding graphs are largely sequential, so extra inter-op threads rarely help).
165
+ - Override with `GTE_INTER_OP_NUM_THREADS` env var.
166
+
167
+ ### Execution Providers
168
+
169
+ `gte` automatically tries XNNPACK for optimized CPU inference, falling back to
170
+ ORT's default CPU provider if unavailable.
171
+
172
+ - **ARM64** (Apple Silicon, AWS Graviton): XNNPACK is typically **~25% faster**
173
+ than plain CPU while producing identical embeddings (cos=1.0, max_abs=0.0).
174
+ - **x86/x64** (Intel, AMD): XNNPACK offers minimal benefit — ORT's default CPU
175
+ provider already uses its own optimized CPU kernels (MLAS), which are well tuned for these chips.
176
+ The auto-detect silently falls back to the default provider.
177
+
178
+ Configure providers explicitly with `GTE_EXECUTION_PROVIDERS` (comma-separated):
179
+
180
+ ```bash
181
+ export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
182
+ ```
183
+
184
+ Set `cpu` or `none` to skip auto-detect and use ORT's default CPU provider.
185
+
186
+ ### Session Pre-Warming
187
+
188
+ ONNX sessions are created lazily per OS thread. In multi-threaded servers (Puma, Sidekiq),
189
+ each thread creates its own session on first use (~100-500ms cold start).
190
+ Pre-warm sessions eagerly at boot:
191
+
192
+ ```ruby
193
+ model = GTE.config(ENV.fetch("GTE_MODEL_DIR"))
194
+
195
+ # Pre-warm thread-local sessions for a Puma server with 5 threads:
196
+ GTE.warmup(model, threads: 5)
197
+ ```
198
+
150
199
  ## Runtime + Result Examples
151
200
 
152
201
  Process-local reuse (recommended for Puma/web servers):
@@ -170,47 +219,64 @@ A model directory must include `tokenizer.json` and one ONNX model, resolved in
170
219
 
171
220
  Input policy is text-only. Graphs requiring unsupported multimodal inputs (such as `pixel_values`) are intentionally rejected.
172
221
 
173
- ## Execution Providers
174
222
 
175
- Default behavior is CPU fallback via ONNX Runtime's default provider (no explicit provider registration).
223
+ ## Development
224
+
225
+ Run commands inside `nix develop` via Make targets:
226
+
227
+ ```bash
228
+ make setup
229
+ make compile
230
+ make test
231
+ make lint
232
+ make ci
233
+ ```
234
+
235
+ ## Benchmarks
176
236
 
177
- Configure providers with `GTE_EXECUTION_PROVIDERS` (comma-separated, case-insensitive).
178
- Supported values:
237
+ ### Docker Rails+Puma+wrk (Real-World HTTP)
179
238
 
180
- - `cpu` or `none`: CPU fallback (skip explicit provider registration)
181
- - `xnnpack`
182
- - `coreml`
239
+ The `bench/rails/` directory contains a full-stack benchmark: Rails 7.1 API app served by Puma,
240
+ load-tested with wrk (randomized queries drawn from 135 diverse texts).
183
241
 
184
- Examples:
242
+ Run for all models:
185
243
 
186
244
  ```bash
187
- export GTE_EXECUTION_PROVIDERS=cpu
188
- export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
245
+ make bench-docker-compare
189
246
  ```
190
247
 
191
- Ruby per-instance override (takes precedence over `GTE_EXECUTION_PROVIDERS`):
248
+ Run for a single model:
192
249
 
193
- ```ruby
194
- model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
195
- config.with(execution_providers: "cpu")
196
- end
250
+ ```bash
251
+ make bench-docker-sweep-siglip2
252
+ make bench-docker-validate # cross-validation checks
197
253
  ```
198
254
 
199
- ## Development
255
+ #### Siglip2 (768-dim, pooler_output)
200
256
 
201
- Run commands inside `nix develop` via Make targets:
257
+ | Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
258
+ |------------|---------|---------------|-------|---------|---------------|
259
+ | c=1 | ~12ms | ~120ms | 9-10× | ~95 | ~10 |
260
+ | c=4 | ~39ms | ~503ms | 10-13× | ~228 | ~10 |
261
+ | c=8 | ~146ms | ~613ms | 3-4× | ~224 | ~10 |
262
+ | c=16 | ~430ms | ~611ms | 1-1.5× | ~226 | ~11 |
202
263
 
203
- ```bash
204
- make setup
205
- make compile
206
- make test
207
- make lint
208
- make ci
209
- ```
264
+ #### E5 (384-dim, last_hidden_state + mean pool)
265
+
266
+ | Concurrency | GTE p90 | Pure Ruby p90 | Ratio | GTE RPS | Pure Ruby RPS |
267
+ |------------|---------|---------------|-------|---------|---------------|
268
+ | c=1 | ~7ms | ~120ms | 16-17× | ~160 | ~10 |
269
+ | c=4 | ~12ms | ~430ms | 35-40× | ~477 | ~10 |
270
+ | c=8 | ~64ms | ~530ms | 8-9× | ~503 | ~10 |
271
+ | c=16 | ~205ms | ~534ms | 2-3× | ~509 | ~11 |
272
+
273
+ GTE releases the GVL during ONNX inference, enabling true parallelism across Puma threads.
274
+ Pure Ruby is GVL-bound (~10 RPS regardless of concurrency).
210
275
 
211
- ## Benchmark
276
+ The Puma thread pool (min=2, max=5) caps throughput: RPS plateaus from about c=4, so higher concurrency mostly adds queueing latency.
277
+ GTE's pipelining and GVL release already saturate the available threads at c=4.
212
278
 
213
- The repo includes a shared multi-runtime benchmark harness:
279
+ ### In-Process Benchmarks
214
280
 
215
281
  ```bash
216
282
  make bench
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.13
1
+ 0.0.14
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.13"
3
+ version = "0.0.14"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -22,11 +22,8 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
22
22
  rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
23
23
  magnus = { version = "0.8", optional = true }
24
24
  ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
25
- ort-sys = "=2.0.0-rc.12"
26
25
  tokenizers = "0.21.0"
27
26
  ndarray = "0.17"
28
- half = "2"
29
- serde = { version = "1", features = ["derive"] }
30
27
  serde_json = "1"
31
28
 
32
29
  [dev-dependencies]
@@ -35,3 +32,28 @@ criterion = "0.5"
35
32
  [[bench]]
36
33
  name = "hot_path"
37
34
  harness = false
35
+
36
+ [lints.rust]
37
+ unsafe_code = "deny"
38
+ rust_2018_idioms = "warn"
39
+ unused_qualifications = "warn"
40
+ unused_results = "warn"
41
+
42
+ [lints.clippy]
43
+ all = "warn"
44
+ pedantic = "warn"
45
+ # Reasonable exceptions to pedantic lints:
46
+ module_name_repetitions = "allow"
47
+ missing_errors_doc = "allow"
48
+ must_use_candidate = "allow"
49
+ cast_possible_truncation = "allow"
50
+ cast_sign_loss = "allow"
51
+ cast_precision_loss = "allow"
52
+ similar_names = "allow"
53
+ too_many_lines = "allow"
54
+ # ndarray::ArrayView types are Copy — passing by value is idiomatic
55
+ needless_pass_by_value = "allow"
56
+ # ort::Outlet::name() is on a private type — closure is required
57
+ redundant_closure_for_method_calls = "allow"
58
+ # Transitive dep conflicts are not actionable in this crate
59
+ multiple_crate_versions = "allow"
@@ -5,9 +5,7 @@ use gte::postprocess::{mean_pool, normalize_l2};
5
5
  use ndarray::{Array2, Array3};
6
6
 
7
7
  fn build_hidden_states(batch: usize, seq: usize, dim: usize) -> Array3<f32> {
8
- Array3::from_shape_fn((batch, seq, dim), |(b, s, d)| {
9
- (((b * 31 + s * 17 + d * 13) % 97) as f32) / 97.0
10
- })
8
+ Array3::from_shape_fn((batch, seq, dim), |(b, s, d)| (((b * 31 + s * 17 + d * 13) % 97) as f32) / 97.0)
11
9
  }
12
10
 
13
11
  fn build_attention_mask(batch: usize, seq: usize) -> Array2<i64> {
@@ -22,15 +20,7 @@ fn bench_mean_pool(c: &mut Criterion) {
22
20
  group.bench_with_input(
23
21
  BenchmarkId::from_parameter(format!("{batch}x{seq}x{dim}")),
24
22
  &(batch, seq, dim),
25
- |b, _| {
26
- b.iter(|| {
27
- mean_pool(
28
- black_box(hidden_states.view()),
29
- black_box(attention_mask.view()),
30
- )
31
- .unwrap()
32
- })
33
- },
23
+ |b, _| b.iter(|| mean_pool(black_box(hidden_states.view()), black_box(attention_mask.view())).unwrap()),
34
24
  );
35
25
  }
36
26
  group.finish();
@@ -39,14 +29,10 @@ fn bench_mean_pool(c: &mut Criterion) {
39
29
  fn bench_normalize_l2(c: &mut Criterion) {
40
30
  let mut group = c.benchmark_group("normalize_l2");
41
31
  for (rows, dim) in [(1, 384), (8, 384), (32, 768), (128, 768)] {
42
- let embeddings = Array2::from_shape_fn((rows, dim), |(row, col)| {
43
- (((row * 19 + col * 7) % 113) as f32) / 113.0
32
+ let embeddings = Array2::from_shape_fn((rows, dim), |(row, col)| (((row * 19 + col * 7) % 113) as f32) / 113.0);
33
+ group.bench_with_input(BenchmarkId::from_parameter(format!("{rows}x{dim}")), &(rows, dim), |b, _| {
34
+ b.iter(|| normalize_l2(black_box(embeddings.clone())))
44
35
  });
45
- group.bench_with_input(
46
- BenchmarkId::from_parameter(format!("{rows}x{dim}")),
47
- &(rows, dim),
48
- |b, _| b.iter(|| normalize_l2(black_box(embeddings.clone()))),
49
- );
50
36
  }
51
37
  group.finish();
52
38
  }
@@ -61,26 +47,14 @@ fn bench_padding_impact(c: &mut Criterion) {
61
47
  let dim = 768;
62
48
  let mut group = c.benchmark_group("padding_impact");
63
49
 
64
- for (label, seq) in [
65
- ("batch_longest/4tok", 4usize),
66
- ("fixed/siglip2_max_64", 64usize),
67
- ("fixed/e5_max_512", 512usize),
68
- ] {
50
+ for (label, seq) in
51
+ [("batch_longest/4tok", 4usize), ("fixed/siglip2_max_64", 64usize), ("fixed/e5_max_512", 512usize)]
52
+ {
69
53
  let hidden_states = build_hidden_states(1, seq, dim);
70
54
  let attention_mask = build_attention_mask(1, seq);
71
- group.bench_with_input(
72
- BenchmarkId::from_parameter(label),
73
- &seq,
74
- |b, _| {
75
- b.iter(|| {
76
- mean_pool(
77
- black_box(hidden_states.view()),
78
- black_box(attention_mask.view()),
79
- )
80
- .unwrap()
81
- })
82
- },
83
- );
55
+ group.bench_with_input(BenchmarkId::from_parameter(label), &seq, |b, _| {
56
+ b.iter(|| mean_pool(black_box(hidden_states.view()), black_box(attention_mask.view())).unwrap())
57
+ });
84
58
  }
85
59
  group.finish();
86
60
  }
@@ -93,7 +67,12 @@ fn bench_padding_impact(c: &mut Criterion) {
93
67
  // Sweeps execution providers for quick local comparison.
94
68
  fn bench_embedding_e2e(c: &mut Criterion) {
95
69
  let cases = [
96
- ("e5", "GTE_BENCH_E5_DIR", "query: cat", "query: ".to_string() + &"the quick brown fox jumps over the lazy dog ".repeat(20)),
70
+ (
71
+ "e5",
72
+ "GTE_BENCH_E5_DIR",
73
+ "query: cat",
74
+ "query: ".to_string() + &"the quick brown fox jumps over the lazy dog ".repeat(20),
75
+ ),
97
76
  ("siglip2", "GTE_BENCH_SIGLIP2_DIR", "cat", "a photo of ".to_string() + &"a cat sitting on a mat ".repeat(10)),
98
77
  ("clip", "GTE_BENCH_CLIP_DIR", "cat", "a photo of ".to_string() + &"a cat sitting on a mat ".repeat(10)),
99
78
  ];
@@ -107,10 +86,7 @@ fn bench_embedding_e2e(c: &mut Criterion) {
107
86
  };
108
87
 
109
88
  for provider in ["cpu", "xnnpack"] {
110
- let overrides = ModelLoadOverrides {
111
- execution_providers: Some(provider),
112
- ..ModelLoadOverrides::default()
113
- };
89
+ let overrides = ModelLoadOverrides { execution_providers: Some(provider), ..ModelLoadOverrides::default() };
114
90
  let embedder = match Embedder::from_dir(&dir, 3, overrides) {
115
91
  Ok(e) => e,
116
92
  Err(err) => {
@@ -122,11 +98,7 @@ fn bench_embedding_e2e(c: &mut Criterion) {
122
98
  for (input_label, input) in [("short", short_input.to_string()), ("long", long_input.clone())] {
123
99
  let id = BenchmarkId::from_parameter(format!("{model_label}/{provider}/{input_label}"));
124
100
  group.bench_with_input(id, &input, |b, text| {
125
- b.iter(|| {
126
- embedder
127
- .embed(black_box(vec![text.clone()]))
128
- .expect("embed succeeds")
129
- })
101
+ b.iter(|| embedder.embed(black_box(&[text.clone()])).expect("embed succeeds"))
130
102
  });
131
103
  }
132
104
  }
@@ -134,11 +106,5 @@ fn bench_embedding_e2e(c: &mut Criterion) {
134
106
  group.finish();
135
107
  }
136
108
 
137
- criterion_group!(
138
- benches,
139
- bench_mean_pool,
140
- bench_normalize_l2,
141
- bench_padding_impact,
142
- bench_embedding_e2e
143
- );
109
+ criterion_group!(benches, bench_mean_pool, bench_normalize_l2, bench_padding_impact, bench_embedding_e2e);
144
110
  criterion_main!(benches);
data/ext/gte/build.rs CHANGED
@@ -1,15 +1,11 @@
1
1
  fn main() {
2
- let version = std::fs::read_to_string("../../VERSION")
3
- .expect("VERSION file not found")
4
- .trim()
5
- .to_string();
2
+ let version = std::fs::read_to_string("../../VERSION").expect("VERSION file not found").trim().to_string();
6
3
 
7
4
  let cargo_version = env!("CARGO_PKG_VERSION");
8
5
 
9
6
  assert_eq!(
10
7
  version, cargo_version,
11
- "VERSION file ({}) doesn't match Cargo.toml ({}). Update Cargo.toml to match.",
12
- version, cargo_version
8
+ "VERSION file ({version}) doesn't match Cargo.toml ({cargo_version}). Update Cargo.toml to match.",
13
9
  );
14
10
 
15
11
  println!("cargo:rerun-if-changed=../../VERSION");
@@ -0,0 +1,5 @@
1
+ edition = "2021"
2
+ max_width = 120
3
+ use_small_heuristics = "Max"
4
+ newline_style = "Unix"
5
+ tab_spaces = 4
@@ -1,11 +1,11 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
3
3
  use crate::model_profile::{
4
- has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
5
- resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
4
+ has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
5
+ resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
6
6
  };
7
7
  use crate::postprocess::normalize_l2 as normalize_l2_rows;
8
- use crate::session::{build_session, run_session, SessionPool};
8
+ use crate::session::{build_session, SessionPool};
9
9
  use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
10
10
  use ndarray::Array2;
11
11
  use std::path::{Path, PathBuf};
@@ -13,7 +13,7 @@ use std::path::{Path, PathBuf};
13
13
  pub struct Embedder {
14
14
  tokenizer: Tokenizer,
15
15
  pool: SessionPool,
16
- config: ModelConfig,
16
+ pub config: ModelConfig,
17
17
  }
18
18
 
19
19
  impl Embedder {
@@ -22,30 +22,17 @@ impl Embedder {
22
22
  P1: AsRef<Path>,
23
23
  P2: AsRef<Path>,
24
24
  {
25
- let tokenizer = Tokenizer::new(
26
- tokenizer_path,
27
- config.max_length,
28
- config.with_type_ids,
29
- config.padding_mode,
30
- None,
31
- )?;
32
- let model_path = model_path.as_ref().to_path_buf();
33
- let session = build_session(&model_path, &config)?;
34
- let pool = SessionPool::new(session, model_path, config.clone());
25
+ let tokenizer =
26
+ Tokenizer::new(tokenizer_path, config.max_length, config.with_type_ids, config.padding_mode, None)?;
27
+ let model_path = model_path.as_ref();
28
+ let session = build_session(model_path, &config)?;
29
+ let pool = SessionPool::new(session, model_path, &config)?;
35
30
  Ok(Self { tokenizer, pool, config })
36
31
  }
37
32
 
38
- pub fn from_dir<P: AsRef<Path>>(
39
- dir: P,
40
- optimization_level: u8,
41
- overrides: ModelLoadOverrides<'_>,
42
- ) -> Result<Self> {
43
- const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
44
- "pooler_output",
45
- "text_embeds",
46
- "sentence_embedding",
47
- "last_hidden_state",
48
- ];
33
+ pub fn from_dir<P: AsRef<Path>>(dir: P, optimization_level: u8, overrides: ModelLoadOverrides<'_>) -> Result<Self> {
34
+ const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] =
35
+ ["pooler_output", "text_embeds", "sentence_embedding", "last_hidden_state"];
49
36
 
50
37
  let dir = dir.as_ref();
51
38
  let tokenizer_path = resolve_tokenizer_path(dir)?;
@@ -57,16 +44,13 @@ impl Embedder {
57
44
  let tokenizer_profile = read_tokenizer_profile(dir);
58
45
  let max_length = if let Some(override_value) = overrides.max_length {
59
46
  if override_value == 0 {
60
- return Err(GteError::Inference(
61
- "max_length override must be greater than 0".to_string(),
62
- ));
47
+ return Err(GteError::Inference("max_length override must be greater than 0".to_string()));
63
48
  }
64
49
  override_value.min(tokenizer_profile.safe_max_length)
65
50
  } else {
66
51
  tokenizer_profile.default_max_length
67
52
  };
68
- let padding_mode =
69
- parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
53
+ let padding_mode = parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
70
54
 
71
55
  let session_config = ModelConfig {
72
56
  max_length,
@@ -77,19 +61,18 @@ impl Embedder {
77
61
  with_attention_mask: true,
78
62
  optimization_level,
79
63
  execution_providers: overrides.execution_providers.map(str::to_string),
64
+ lowercase_input: overrides.lowercase_input.unwrap_or(false),
65
+ max_input_chars: overrides.max_input_chars,
80
66
  };
81
67
  let session = build_session(&model_path, &session_config)?;
82
68
 
83
69
  validate_supported_text_inputs(&session, "text embedding")?;
84
70
  let with_type_ids = has_input(&session, "token_type_ids");
85
71
  let with_attention_mask = has_input(&session, "attention_mask");
86
- let output_tensor =
87
- select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
72
+ let output_tensor = select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
88
73
  let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
89
74
  if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
90
- return Err(GteError::Inference(
91
- "cannot use mean pooling without attention_mask input".to_string(),
92
- ));
75
+ return Err(GteError::Inference("cannot use mean pooling without attention_mask input".to_string()));
93
76
  }
94
77
 
95
78
  let config = ModelConfig {
@@ -101,6 +84,8 @@ impl Embedder {
101
84
  with_attention_mask,
102
85
  optimization_level,
103
86
  execution_providers: overrides.execution_providers.map(str::to_string),
87
+ lowercase_input: overrides.lowercase_input.unwrap_or(false),
88
+ max_input_chars: overrides.max_input_chars,
104
89
  };
105
90
 
106
91
  let tokenizer = Tokenizer::new(
@@ -111,29 +96,72 @@ impl Embedder {
111
96
  tokenizer_profile.fixed_padding_length,
112
97
  )?;
113
98
 
114
- let pool = SessionPool::new(session, model_path, session_config);
99
+ let pool = SessionPool::new(session, &model_path, &session_config)?;
115
100
  Ok(Self { tokenizer, pool, config })
116
101
  }
117
102
 
118
- pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
119
- self.embed_ref(&texts)
103
+ pub fn embed(&self, texts: &[String]) -> Result<Array2<f32>> {
104
+ self.embed_ref(texts)
120
105
  }
121
106
 
122
107
  pub fn embed_ref(&self, texts: &[String]) -> Result<Array2<f32>> {
123
- let tokenized = self.tokenize(texts)?;
108
+ let sanitized: Vec<String>;
109
+ let input = if self.config.lowercase_input || self.config.max_input_chars.is_some() {
110
+ sanitized = texts
111
+ .iter()
112
+ .map(|t| {
113
+ let mut s = if self.config.lowercase_input { t.to_lowercase() } else { t.clone() };
114
+ if let Some(max_chars) = self.config.max_input_chars {
115
+ s.truncate(max_chars.min(s.len()));
116
+ }
117
+ s
118
+ })
119
+ .collect();
120
+ &sanitized
121
+ } else {
122
+ texts
123
+ };
124
+ let tokenized = self.tokenize(input)?;
124
125
  self.run(&tokenized)
125
126
  }
126
127
 
127
- pub fn tokenize(&self, texts: &[String]) -> crate::error::Result<Tokenized> {
128
+ pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
128
129
  self.tokenizer.tokenize(texts)
129
130
  }
130
131
 
131
- pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
132
- let mut session = self.pool.acquire()?;
133
- run_session(&mut session, tokenized, &self.config)
132
+ pub fn run(&self, tokenized: &Tokenized) -> Result<Array2<f32>> {
133
+ self.pool.run(tokenized, &self.config)
134
134
  }
135
135
  }
136
136
 
137
137
  pub fn normalize_l2(embeddings: Array2<f32>) -> Array2<f32> {
138
138
  normalize_l2_rows(embeddings)
139
139
  }
140
+
141
/// Heuristically decides whether an ONNX output name indicates already-L2-normalized values.
///
/// Only the last `/`-separated path segment is inspected, case-insensitively. It matches
/// `normalized`, `l2_norm` or `l2norm`. Explicitly negated forms such as `unnormalized`,
/// `denormalized` or `non_normalized` contain `normalized` as a substring but mean the
/// opposite. They are rejected first so such outputs still get normalized.
pub fn output_name_suggests_normalized(name: &str) -> bool {
    const NEGATED: [&str; 6] =
        ["unnormalized", "denormalized", "non_normalized", "non-normalized", "nonnormalized", "not_normalized"];

    let lower = name.to_ascii_lowercase();
    // `rsplit` always yields at least one item; `unwrap_or` is purely defensive.
    let base = lower.rsplit('/').next().unwrap_or(&lower);

    if NEGATED.iter().any(|negated| base.contains(negated)) {
        return false;
    }
    base.contains("normalized") || base.contains("l2_norm") || base.contains("l2norm")
}
146
+
147
#[cfg(test)]
mod normalize_tests {
    use super::output_name_suggests_normalized;

    #[test]
    fn detects_normalized_output_names() {
        for name in [
            "pooled_sentence_embeddings_debiased_normalized",
            "embeddings/L2_Normalized",
            "l2norm_output",
            "norm/l2_norm_tensor",
        ] {
            assert!(output_name_suggests_normalized(name), "{name:?} should be treated as already normalized");
        }
    }

    #[test]
    fn does_not_detect_raw_output_names() {
        for name in ["last_hidden_state", "text_embeds", "pooler_output", "sentence_embedding", "logits"] {
            assert!(!output_name_suggests_normalized(name), "{name:?} should be treated as a raw output");
        }
    }
}
data/ext/gte/src/error.rs CHANGED
@@ -9,10 +9,10 @@ pub enum GteError {
9
9
  impl std::fmt::Display for GteError {
10
10
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11
11
  match self {
12
- GteError::Tokenizer(msg) => write!(f, "GTE tokenizer error: {}", msg),
13
- GteError::Inference(msg) => write!(f, "GTE inference error: {}", msg),
14
- GteError::Ort(msg) => write!(f, "GTE ORT error: {}", msg),
15
- GteError::Shape(msg) => write!(f, "GTE shape error: {}", msg),
12
+ GteError::Tokenizer(msg) => write!(f, "GTE tokenizer error: {msg}"),
13
+ GteError::Inference(msg) => write!(f, "GTE inference error: {msg}"),
14
+ GteError::Ort(msg) => write!(f, "GTE ORT error: {msg}"),
15
+ GteError::Shape(msg) => write!(f, "GTE shape error: {msg}"),
16
16
  }
17
17
  }
18
18
  }
data/ext/gte/src/lib.rs CHANGED
@@ -19,7 +19,7 @@ use magnus::{prelude::*, Error, Ruby};
19
19
  fn init(ruby: &Ruby) -> Result<(), Error> {
20
20
  let module = ruby.define_module("GTE")?;
21
21
  module.define_error("Error", ruby.exception_standard_error())?;
22
- crate::ruby_embedder::register(ruby)?;
22
+ ruby_embedder::register(ruby)?;
23
23
  std::panic::set_hook(Box::new(|info| {
24
24
  let msg = info
25
25
  .payload()
@@ -23,6 +23,8 @@ pub struct ModelConfig {
23
23
  pub with_attention_mask: bool,
24
24
  pub optimization_level: u8,
25
25
  pub execution_providers: Option<String>,
26
+ pub lowercase_input: bool,
27
+ pub max_input_chars: Option<usize>,
26
28
  }
27
29
 
28
30
  #[derive(Debug, Clone, Copy, Default)]
@@ -32,4 +34,6 @@ pub struct ModelLoadOverrides<'a> {
32
34
  pub max_length: Option<usize>,
33
35
  pub padding: Option<&'a str>,
34
36
  pub execution_providers: Option<&'a str>,
37
+ pub lowercase_input: Option<bool>,
38
+ pub max_input_chars: Option<usize>,
35
39
  }