gte 0.0.11 → 0.0.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8123419de1a0fe86fae1de8808354318ca5b8d575514c7452ac45f71b657252e
4
- data.tar.gz: da92e0ba5cab358dde743f4b444c5500f88757268084138571bd578bc505f1f3
3
+ metadata.gz: 278028df09fbcdd14fd583f0af5e1a8c9553adb28fe7aa0bc67b67666dbbdccd
4
+ data.tar.gz: ce994e3f505200ed4654ca8f87f585ff88919201fe82dd79007622f07a3d1ea0
5
5
  SHA512:
6
- metadata.gz: 1fd3dc5a8a7e005d797f352c3fefea348a6badc529eca0169b9d349ffda0aa707f0327897112e019340091b2613e6948210966676ddc6717e72146f506c93fae
7
- data.tar.gz: ea12d947a02133c69990f24c104a2003df465149fa8beb3ec857ced623b3ba39310f160d956c3da40636a58e92e50720e37c3845905db74dc772ca8b4a6d0c85
6
+ metadata.gz: 742f1830ff2b83f89726be527c4323a81649b04f341b7adc0544a9000373f6a097c0b4b4ba211ead5912ba45d876565fbaab6d723ef8f06c488ab7827323f827
7
+ data.tar.gz: 75e91b3d4c3980b166268c6468b96bebe4b74db999e0cee433a295e57d89bec95c7614b004c61e8b3ed88cff30f02f3b6aff74de710d3dd3bb34552f36fb3422
data/README.md CHANGED
@@ -33,10 +33,6 @@ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
33
33
  config.with(normalize: false)
34
34
  end
35
35
 
36
- single_thread = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
37
- config.with(threads: 1)
38
- end
39
-
40
36
  custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
41
37
  config.with(
42
38
  output_tensor: "last_hidden_state",
@@ -50,7 +46,6 @@ end
50
46
  Config fields and defaults:
51
47
 
52
48
  - `model_dir`: absolute path to model directory
53
- - `threads`: `1` (default tuned for p95 latency; use `0` for ONNX Runtime auto-thread mode)
54
49
  - `optimization_level`: `3`
55
50
  - `model_name`: `nil`
56
51
  - `normalize`: `true` (L2 normalization at Ruby-facing API)
@@ -64,11 +59,48 @@ Notes:
64
59
  - Return a `Config::Text` from the block (for example, `config.with(...)`).
65
60
  - Model instances are cached by full config key; different config values create different cached instances.
66
61
 
62
+ Common model presets:
63
+
64
+ ```ruby
65
+ e5 = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
66
+ config.with(
67
+ model_name: "model.onnx",
68
+ output_tensor: "last_hidden_state",
69
+ max_length: 512,
70
+ execution_providers: "cpu"
71
+ )
72
+ end
73
+
74
+ siglip2 = GTE.config(ENV.fetch("GTE_SIGLIP2_DIR")) do |config|
75
+ config.with(
76
+ model_name: "text_model_int8.onnx",
77
+ output_tensor: "pooler_output",
78
+ max_length: 64,
79
+ execution_providers: "cpu"
80
+ )
81
+ end
82
+
83
+ clip = GTE.config(ENV.fetch("GTE_CLIP_DIR")) do |config|
84
+ config.with(
85
+ output_tensor: "sentence_embedding",
86
+ max_length: 512,
87
+ execution_providers: "cpu"
88
+ )
89
+ end
90
+ ```
91
+
92
+ Picking a specific layer:
93
+
94
+ - Use `output_tensor:` to request a named model output.
95
+ - `last_hidden_state` gives token-level hidden states and is mean-pooled by `gte` when the tensor is rank 3.
96
+ - `pooler_output`, `sentence_embedding`, and similar 2D tensors are returned directly and then L2-normalized by default.
97
+ - If the requested tensor is not present in the model, `gte` raises an error instead of silently falling back.
98
+
67
99
  Low-level embedder setup (without model cache):
68
100
 
69
101
  ```ruby
70
102
  embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
71
- config.with(threads: 1, execution_providers: "cpu")
103
+ config.with(execution_providers: "cpu")
72
104
  end
73
105
  ```
74
106
 
@@ -78,7 +110,7 @@ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
78
110
 
79
111
  ```ruby
80
112
  reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
81
- config.with(sigmoid: true, threads: 1)
113
+ config.with(sigmoid: true)
82
114
  end
83
115
 
84
116
  query = "how to train a neural network?"
@@ -102,7 +134,6 @@ ranked = reranker.rerank(query: query, candidates: candidates)
102
134
  Reranker config fields and defaults:
103
135
 
104
136
  - `model_dir`: absolute path to model directory
105
- - `threads`: `1`
106
137
  - `optimization_level`: `3`
107
138
  - `model_name`: `nil`
108
139
  - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
@@ -111,6 +142,11 @@ Reranker config fields and defaults:
111
142
  - `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
112
143
  - `execution_providers`: `nil`
113
144
 
145
+ Session pool sizing:
146
+
147
+ - `GTE_SESSION_POOL_CAP`: optional positive integer cap for internal ONNX session pool size.
148
+ - Unset by default; runtime uses available CPU parallelism.
149
+
114
150
  ## Runtime + Result Examples
115
151
 
116
152
  Process-local reuse (recommended for Puma/web servers):
@@ -185,7 +221,7 @@ nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
185
221
 
186
222
  - `make bench`: Puma-like single-request comparison at concurrency `16`
187
223
  - `rake bench:pure_compare`: batch amortization comparison
188
- - `rake bench:matrix_sweep`: GTE provider/thread sweep using the shared result schema
224
+ - `rake bench:matrix_sweep`: GTE provider sweep using the shared result schema
189
225
  - Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
190
226
 
191
227
  To run benchmark + append a `RUNS.md` entry + enforce goal checks:
data/Rakefile CHANGED
@@ -74,7 +74,7 @@ namespace :bench do
74
74
  )
75
75
  end
76
76
 
77
- desc 'Sweep execution-provider and thread settings for Puma-like benchmark'
77
+ desc 'Sweep execution-provider settings for Puma-like benchmark'
78
78
  task :matrix_sweep do
79
79
  run_in_nix(
80
80
  'bundle', 'exec', 'ruby', 'bench/puma_matrix_sweep.rb',
data/VERSION CHANGED
@@ -1 +1 @@
1
- 0.0.11
1
+ 0.0.13
data/ext/gte/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "gte"
3
- version = "0.0.11"
3
+ version = "0.0.13"
4
4
  edition = "2021"
5
5
  authors = ["elcuervo <elcuervo@elcuervo.net>"]
6
6
  license = "MIT"
@@ -90,7 +90,7 @@ fn bench_padding_impact(c: &mut Criterion) {
90
90
  // GTE_BENCH_E5_DIR — sentence-transformers / E5-style text model dir
91
91
  // GTE_BENCH_SIGLIP2_DIR — siglip2 text encoder dir
92
92
  // GTE_BENCH_CLIP_DIR — clip text encoder dir
93
- // Sweeps threads {0 (auto/all-cores), 1, 2} to validate DEFAULT_THREADS=0.
93
+ // Sweeps execution providers for quick local comparison.
94
94
  fn bench_embedding_e2e(c: &mut Criterion) {
95
95
  let cases = [
96
96
  ("e5", "GTE_BENCH_E5_DIR", "query: cat", "query: ".to_string() + &"the quick brown fox jumps over the lazy dog ".repeat(20)),
@@ -106,17 +106,21 @@ fn bench_embedding_e2e(c: &mut Criterion) {
106
106
  continue;
107
107
  };
108
108
 
109
- for &threads in &[0usize, 1, 2] {
110
- let embedder = match Embedder::from_dir(&dir, threads, 3, ModelLoadOverrides::default()) {
109
+ for provider in ["cpu", "xnnpack"] {
110
+ let overrides = ModelLoadOverrides {
111
+ execution_providers: Some(provider),
112
+ ..ModelLoadOverrides::default()
113
+ };
114
+ let embedder = match Embedder::from_dir(&dir, 3, overrides) {
111
115
  Ok(e) => e,
112
116
  Err(err) => {
113
- eprintln!("skip {model_label} threads={threads}: {err}");
117
+ eprintln!("skip {model_label} provider={provider}: {err}");
114
118
  continue;
115
119
  }
116
120
  };
117
121
 
118
122
  for (input_label, input) in [("short", short_input.to_string()), ("long", long_input.clone())] {
119
- let id = BenchmarkId::from_parameter(format!("{model_label}/threads_{threads}/{input_label}"));
123
+ let id = BenchmarkId::from_parameter(format!("{model_label}/{provider}/{input_label}"));
120
124
  group.bench_with_input(id, &input, |b, text| {
121
125
  b.iter(|| {
122
126
  embedder
@@ -37,7 +37,6 @@ impl Embedder {
37
37
 
38
38
  pub fn from_dir<P: AsRef<Path>>(
39
39
  dir: P,
40
- num_threads: usize,
41
40
  optimization_level: u8,
42
41
  overrides: ModelLoadOverrides<'_>,
43
42
  ) -> Result<Self> {
@@ -76,7 +75,6 @@ impl Embedder {
76
75
  mode: ExtractorMode::Raw,
77
76
  with_type_ids: false,
78
77
  with_attention_mask: true,
79
- num_threads,
80
78
  optimization_level,
81
79
  execution_providers: overrides.execution_providers.map(str::to_string),
82
80
  };
@@ -101,7 +99,6 @@ impl Embedder {
101
99
  mode,
102
100
  with_type_ids,
103
101
  with_attention_mask,
104
- num_threads,
105
102
  optimization_level,
106
103
  execution_providers: overrides.execution_providers.map(str::to_string),
107
104
  };
@@ -119,7 +116,11 @@ impl Embedder {
119
116
  }
120
117
 
121
118
  pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
122
- let tokenized = self.tokenize(&texts)?;
119
+ self.embed_ref(&texts)
120
+ }
121
+
122
+ pub fn embed_ref(&self, texts: &[String]) -> Result<Array2<f32>> {
123
+ let tokenized = self.tokenize(texts)?;
123
124
  self.run(&tokenized)
124
125
  }
125
126
 
@@ -21,7 +21,6 @@ pub struct ModelConfig {
21
21
  pub mode: ExtractorMode,
22
22
  pub with_type_ids: bool,
23
23
  pub with_attention_mask: bool,
24
- pub num_threads: usize,
25
24
  pub optimization_level: u8,
26
25
  pub execution_providers: Option<String>,
27
26
  }
@@ -201,10 +201,14 @@ pub fn select_output_tensor(
201
201
  }
202
202
  }
203
203
 
204
- session
205
- .outputs()
206
- .first()
207
- .map(|o| o.name().to_owned())
204
+ let outputs = session.outputs();
205
+ let best = outputs
206
+ .iter()
207
+ .find(|o| {
208
+ matches!(o.dtype(), ort::value::ValueType::Tensor { shape, .. } if shape.len() == 2)
209
+ })
210
+ .or_else(|| outputs.first());
211
+ best.map(|o| o.name().to_owned())
208
212
  .ok_or_else(|| GteError::Inference("model has no outputs".into()))
209
213
  }
210
214
 
@@ -28,7 +28,6 @@ pub struct Reranker {
28
28
  impl Reranker {
29
29
  pub fn from_dir<P: AsRef<Path>>(
30
30
  dir: P,
31
- num_threads: usize,
32
31
  optimization_level: u8,
33
32
  overrides: ModelLoadOverrides<'_>,
34
33
  ) -> Result<Self> {
@@ -60,7 +59,6 @@ impl Reranker {
60
59
  mode: crate::model_config::ExtractorMode::Raw,
61
60
  with_type_ids: false,
62
61
  with_attention_mask: true,
63
- num_threads,
64
62
  optimization_level,
65
63
  execution_providers: overrides.execution_providers.map(str::to_string),
66
64
  };
@@ -4,7 +4,6 @@ use crate::embedder::{normalize_l2, Embedder};
4
4
  use crate::error::GteError;
5
5
  use crate::model_config::ModelLoadOverrides;
6
6
  use crate::reranker::Reranker;
7
- use crate::tokenizer::Tokenized;
8
7
  use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
9
8
  use std::os::raw::c_void;
10
9
  use std::panic::{catch_unwind, AssertUnwindSafe};
@@ -33,10 +32,9 @@ pub struct RbTensor {
33
32
  // GVL-release helpers
34
33
  // ---------------------------------------------------------------------------
35
34
 
36
- // Tokenized holds only Vec<i64> fields — safe to send across threads.
37
35
  struct InferArgs {
38
36
  embedder: *const Embedder,
39
- tokenized: *const Tokenized,
37
+ texts: *const Vec<String>,
40
38
  normalize: bool,
41
39
  result: Option<crate::error::Result<ndarray::Array2<f32>>>,
42
40
  }
@@ -45,7 +43,8 @@ unsafe impl Send for InferArgs {}
45
43
 
46
44
  struct ScoreArgs {
47
45
  reranker: *const Reranker,
48
- pairs: *const Vec<(String, String)>,
46
+ query: *const String,
47
+ candidates: *const Vec<String>,
49
48
  apply_sigmoid: bool,
50
49
  result: Option<crate::error::Result<Vec<f32>>>,
51
50
  }
@@ -62,12 +61,11 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
62
61
  }
63
62
  }
64
63
 
65
- unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
64
+ unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
66
65
  let args = &mut *(ptr as *mut InferArgs);
67
66
  let run_result = catch_unwind(AssertUnwindSafe(|| {
68
- // Tokenization happens before GVL release (in rb_embed / rb_embed_one).
69
- // Only ONNX inference runs here without the GVL.
70
- let embeddings = (*args.embedder).run(&*args.tokenized)?;
67
+ // Full embedding path (tokenization + inference) runs without the GVL.
68
+ let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
71
69
  if args.normalize { Ok(normalize_l2(embeddings)) } else { Ok(embeddings) }
72
70
  }));
73
71
  args.result = Some(match run_result {
@@ -83,7 +81,7 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
83
81
  unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
84
82
  let args = &mut *(ptr as *mut ScoreArgs);
85
83
  let run_result = catch_unwind(AssertUnwindSafe(|| {
86
- (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)
84
+ (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
87
85
  }));
88
86
  args.result = Some(match run_result {
89
87
  Ok(result) => result,
@@ -98,17 +96,17 @@ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
98
96
  fn infer_without_gvl(
99
97
  embedder: &Arc<Embedder>,
100
98
  normalize: bool,
101
- tokenized: &Tokenized,
99
+ texts: Vec<String>,
102
100
  ) -> Result<ndarray::Array2<f32>, Error> {
103
101
  let embeddings = unsafe {
104
102
  let mut args = InferArgs {
105
103
  embedder: Arc::as_ptr(embedder),
106
- tokenized: tokenized as *const Tokenized,
104
+ texts: &texts as *const Vec<String>,
107
105
  normalize,
108
106
  result: None,
109
107
  };
110
108
  rb_sys::rb_thread_call_without_gvl(
111
- Some(run_without_gvl),
109
+ Some(run_embed_without_gvl),
112
110
  &mut args as *mut InferArgs as *mut c_void,
113
111
  None,
114
112
  std::ptr::null_mut(),
@@ -125,13 +123,15 @@ fn infer_without_gvl(
125
123
 
126
124
  fn score_without_gvl(
127
125
  reranker: &Arc<Reranker>,
128
- pairs: Vec<(String, String)>,
126
+ query: String,
127
+ candidates: Vec<String>,
129
128
  apply_sigmoid: bool,
130
129
  ) -> Result<Vec<f32>, Error> {
131
130
  let scores = unsafe {
132
131
  let mut args = ScoreArgs {
133
132
  reranker: Arc::as_ptr(reranker),
134
- pairs: &pairs as *const Vec<(String, String)>,
133
+ query: &query as *const String,
134
+ candidates: &candidates as *const Vec<String>,
135
135
  apply_sigmoid,
136
136
  result: None,
137
137
  };
@@ -170,7 +170,6 @@ impl RbEmbedder {
170
170
  pub fn rb_new(
171
171
  _ruby: &Ruby,
172
172
  dir_path: String,
173
- num_threads: usize,
174
173
  optimization_level: u8,
175
174
  model_name: String,
176
175
  normalize: bool,
@@ -191,21 +190,19 @@ impl RbEmbedder {
191
190
  padding: padding_override,
192
191
  execution_providers: execution_providers_override,
193
192
  };
194
- let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, overrides)
193
+ let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides)
195
194
  .map_err(magnus::Error::from)?;
196
195
  Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
197
196
  }
198
197
 
199
198
  pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
200
199
  let texts: Vec<String> = texts.to_vec()?;
201
- let tokenized = rb_self.inner.tokenize(&texts).map_err(magnus::Error::from)?;
202
- let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, &tokenized)?;
200
+ let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
203
201
  tensor_from_array(embeddings)
204
202
  }
205
203
 
206
204
  pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
207
- let tokenized = rb_self.inner.tokenize(&[text]).map_err(magnus::Error::from)?;
208
- let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, &tokenized)?;
205
+ let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
209
206
  tensor_from_array(embeddings)
210
207
  }
211
208
  }
@@ -214,7 +211,6 @@ impl RbReranker {
214
211
  pub fn rb_new(
215
212
  _ruby: &Ruby,
216
213
  dir_path: String,
217
- num_threads: usize,
218
214
  optimization_level: u8,
219
215
  model_name: String,
220
216
  sigmoid: bool,
@@ -235,7 +231,7 @@ impl RbReranker {
235
231
  padding: padding_override,
236
232
  execution_providers: execution_providers_override,
237
233
  };
238
- let reranker = Reranker::from_dir(&dir_path, num_threads, optimization_level, overrides)
234
+ let reranker = Reranker::from_dir(&dir_path, optimization_level, overrides)
239
235
  .map_err(magnus::Error::from)?;
240
236
  Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
241
237
  }
@@ -247,8 +243,7 @@ impl RbReranker {
247
243
  candidates: RArray,
248
244
  ) -> Result<RArray, Error> {
249
245
  let candidates: Vec<String> = candidates.to_vec()?;
250
- let pairs: Vec<(String, String)> = candidates.into_iter().map(|c| (query.clone(), c)).collect();
251
- let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
246
+ let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
252
247
  let out = ruby.ary_new_capa(scores.len());
253
248
  for score in scores {
254
249
  out.push(score)?;
@@ -345,12 +340,12 @@ impl RbTensor {
345
340
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
346
341
  let module = ruby.define_module("GTE")?;
347
342
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
348
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
343
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
349
344
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
350
345
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
351
346
 
352
347
  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
353
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
348
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
354
349
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
355
350
 
356
351
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;
@@ -7,7 +7,7 @@ use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
7
7
  use ort::execution_providers::{
8
8
  CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
9
9
  };
10
- use ort::session::Session;
10
+ use ort::session::{OutputSelector, RunOptions, Session};
11
11
  use std::path::{Path, PathBuf};
12
12
  use std::sync::atomic::{AtomicUsize, Ordering};
13
13
  use std::sync::{Condvar, Mutex};
@@ -27,8 +27,6 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
27
27
  let mut builder = Session::builder()
28
28
  .map_err(ort_err)?
29
29
  .with_optimization_level(opt_level)
30
- .map_err(ort_err)?
31
- .with_memory_pattern(false)
32
30
  .map_err(ort_err)?;
33
31
 
34
32
  let providers = preferred_execution_providers(config.execution_providers.as_deref());
@@ -38,15 +36,6 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
38
36
  .map_err(ort_err)?;
39
37
  }
40
38
 
41
- if config.num_threads > 0 {
42
- builder = builder
43
- .with_intra_threads(config.num_threads)
44
- .map_err(ort_err)?;
45
- builder = builder
46
- .with_inter_threads(config.num_threads)
47
- .map_err(ort_err)?;
48
- }
49
-
50
39
  builder.commit_from_file(model_path).map_err(ort_err)
51
40
  }
52
41
 
@@ -54,25 +43,17 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
54
43
  // Session pool
55
44
  // ---------------------------------------------------------------------------
56
45
 
57
- fn pool_capacity(num_threads: usize) -> usize {
58
- let available_parallelism = std::thread::available_parallelism()
46
+ fn pool_capacity() -> usize {
47
+ let available = std::thread::available_parallelism()
59
48
  .map(|n| n.get())
60
49
  .unwrap_or(1);
61
- pool_capacity_with_parallelism(num_threads, available_parallelism)
50
+ parse_pool_capacity_override().map_or(available, |cap| cap.min(available).max(1))
62
51
  }
63
52
 
64
- fn pool_capacity_with_parallelism(num_threads: usize, available_parallelism: usize) -> usize {
65
- if available_parallelism == 0 {
66
- return 1;
67
- }
68
-
69
- // Auto-thread mode: ORT grabs all cores per session. One session avoids
70
- // N² intra-op oversubscription when multiple Ruby threads call concurrently.
71
- if num_threads == 0 {
72
- return 1;
73
- }
74
-
75
- available_parallelism.div_ceil(num_threads).max(1)
53
+ fn parse_pool_capacity_override() -> Option<usize> {
54
+ let raw = std::env::var("GTE_SESSION_POOL_CAP").ok()?;
55
+ let parsed = raw.trim().parse::<usize>().ok()?;
56
+ (parsed > 0).then_some(parsed)
76
57
  }
77
58
 
78
59
  pub struct SessionPool {
@@ -86,7 +67,7 @@ pub struct SessionPool {
86
67
 
87
68
  impl SessionPool {
88
69
  pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
89
- let capacity = pool_capacity(build_config.num_threads);
70
+ let capacity = pool_capacity();
90
71
  Self {
91
72
  sessions: Mutex::new(vec![initial]),
92
73
  available: Condvar::new(),
@@ -235,8 +216,11 @@ pub fn run_session(
235
216
  config: &ModelConfig,
236
217
  ) -> Result<Array2<f32>> {
237
218
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
219
+ let run_opts = RunOptions::new()
220
+ .map_err(|e| GteError::Ort(e.to_string()))?
221
+ .with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
238
222
  let outputs = session
239
- .run(input_tensors.inputs)
223
+ .run_with_options(input_tensors.inputs, &run_opts)
240
224
  .map_err(|e| GteError::Ort(e.to_string()))?;
241
225
  let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
242
226
 
@@ -282,7 +266,7 @@ mod tests {
282
266
  use ndarray::{array, ArrayView2};
283
267
 
284
268
  use super::{
285
- extract_embeddings, parse_provider_registrations, pool_capacity_with_parallelism,
269
+ extract_embeddings, parse_pool_capacity_override, parse_provider_registrations,
286
270
  resolve_provider_order_with_env,
287
271
  };
288
272
 
@@ -294,7 +278,6 @@ mod tests {
294
278
  mode,
295
279
  with_type_ids: false,
296
280
  with_attention_mask: true,
297
- num_threads: 1,
298
281
  optimization_level: 3,
299
282
  execution_providers: None,
300
283
  }
@@ -343,21 +326,30 @@ mod tests {
343
326
  }
344
327
 
345
328
  #[test]
346
- fn pool_capacity_uses_single_session_for_auto_thread_mode() {
347
- // Auto-thread = ORT uses all cores per session. Pool=1 avoids N²
348
- // intra-op oversubscription under concurrent Ruby threads.
349
- assert_eq!(pool_capacity_with_parallelism(0, 1), 1);
350
- assert_eq!(pool_capacity_with_parallelism(0, 4), 1);
351
- assert_eq!(pool_capacity_with_parallelism(0, 8), 1);
352
- }
329
+ fn parse_pool_capacity_override_uses_positive_integer_only() {
330
+ unsafe {
331
+ std::env::remove_var("GTE_SESSION_POOL_CAP");
332
+ }
333
+ assert_eq!(parse_pool_capacity_override(), None);
353
334
 
354
- #[test]
355
- fn pool_capacity_scales_with_available_parallelism() {
356
- assert_eq!(pool_capacity_with_parallelism(1, 1), 1);
357
- assert_eq!(pool_capacity_with_parallelism(1, 8), 8);
358
- assert_eq!(pool_capacity_with_parallelism(2, 8), 4);
359
- assert_eq!(pool_capacity_with_parallelism(3, 8), 3);
360
- assert_eq!(pool_capacity_with_parallelism(8, 4), 1);
335
+ unsafe {
336
+ std::env::set_var("GTE_SESSION_POOL_CAP", "0");
337
+ }
338
+ assert_eq!(parse_pool_capacity_override(), None);
339
+
340
+ unsafe {
341
+ std::env::set_var("GTE_SESSION_POOL_CAP", "4");
342
+ }
343
+ assert_eq!(parse_pool_capacity_override(), Some(4));
344
+
345
+ unsafe {
346
+ std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
347
+ }
348
+ assert_eq!(parse_pool_capacity_override(), None);
349
+
350
+ unsafe {
351
+ std::env::remove_var("GTE_SESSION_POOL_CAP");
352
+ }
361
353
  }
362
354
 
363
355
  #[test]
@@ -8,7 +8,7 @@ fn model_dir(env_var: &str) -> Option<String> {
8
8
  #[test]
9
9
  fn test_e5_single_embedding_shape() {
10
10
  let Some(dir) = model_dir("GTE_BENCH_E5_DIR") else { return };
11
- let embedder = Embedder::from_dir(&dir, 0, 3, ModelLoadOverrides::default())
11
+ let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
12
12
  .expect("embedder should initialize");
13
13
  let result = embedder
14
14
  .embed(vec!["query: Hello world".to_string()])
@@ -21,7 +21,7 @@ fn test_e5_single_embedding_shape() {
21
21
  #[test]
22
22
  fn test_clip_single_embedding_shape() {
23
23
  let Some(dir) = model_dir("GTE_BENCH_CLIP_DIR") else { return };
24
- let embedder = Embedder::from_dir(&dir, 0, 3, ModelLoadOverrides::default())
24
+ let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
25
25
  .expect("embedder should initialize");
26
26
  let result = embedder
27
27
  .embed(vec!["a photo of a cat".to_string()])
@@ -34,7 +34,7 @@ fn test_clip_single_embedding_shape() {
34
34
  #[test]
35
35
  fn test_e5_batch_embedding_shape() {
36
36
  let Some(dir) = model_dir("GTE_BENCH_E5_DIR") else { return };
37
- let embedder = Embedder::from_dir(&dir, 0, 3, ModelLoadOverrides::default())
37
+ let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
38
38
  .expect("embedder should initialize");
39
39
  let texts = vec![
40
40
  "query: first sentence".to_string(),
@@ -51,7 +51,7 @@ fn test_e5_batch_embedding_shape() {
51
51
  #[test]
52
52
  fn test_e5_long_input_truncation_no_error() {
53
53
  let Some(dir) = model_dir("GTE_BENCH_E5_DIR") else { return };
54
- let embedder = Embedder::from_dir(&dir, 0, 3, ModelLoadOverrides::default())
54
+ let embedder = Embedder::from_dir(&dir, 0, ModelLoadOverrides::default())
55
55
  .expect("embedder should initialize");
56
56
  let very_long_text = "word ".repeat(1000);
57
57
  let result = embedder
data/lib/gte/config.rb CHANGED
@@ -3,12 +3,12 @@
3
3
  module GTE
4
4
  module Config
5
5
  Text = Data.define(
6
- :model_dir, :threads, :optimization_level,
6
+ :model_dir, :optimization_level,
7
7
  :model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
8
8
  )
9
9
 
10
10
  Reranker = Data.define(
11
- :model_dir, :threads, :optimization_level,
11
+ :model_dir, :optimization_level,
12
12
  :model_name, :sigmoid, :output_tensor, :max_length, :padding, :execution_providers
13
13
  )
14
14
  end
data/lib/gte/embedder.rb CHANGED
@@ -2,7 +2,6 @@
2
2
 
3
3
  module GTE
4
4
  class Embedder
5
- DEFAULT_THREADS = 0
6
5
  DEFAULT_OPTIMIZATION_LEVEL = 3
7
6
 
8
7
  class << self
@@ -15,7 +14,6 @@ module GTE
15
14
  def from_config(config)
16
15
  new(
17
16
  config.model_dir,
18
- config.threads,
19
17
  config.optimization_level,
20
18
  config.model_name.to_s,
21
19
  config.normalize,
@@ -29,7 +27,6 @@ module GTE
29
27
  def default_config(model_dir)
30
28
  Config::Text.new(
31
29
  model_dir: File.expand_path(model_dir),
32
- threads: DEFAULT_THREADS,
33
30
  optimization_level: DEFAULT_OPTIMIZATION_LEVEL,
34
31
  model_name: nil,
35
32
  normalize: true,
data/lib/gte/reranker.rb CHANGED
@@ -19,7 +19,6 @@ module GTE
19
19
  def default_config(model_dir)
20
20
  Config::Reranker.new(
21
21
  model_dir: File.expand_path(model_dir),
22
- threads: 1,
23
22
  optimization_level: 3,
24
23
  model_name: nil,
25
24
  sigmoid: false,
@@ -33,7 +32,6 @@ module GTE
33
32
  def build(cfg)
34
33
  new(
35
34
  cfg.model_dir,
36
- cfg.threads,
37
35
  cfg.optimization_level,
38
36
  cfg.model_name.to_s,
39
37
  cfg.sigmoid,
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gte
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.11
4
+ version: 0.0.13
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo