gte 0.0.12 → 0.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,221 +4,151 @@ use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
6
  use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
7
- use ort::execution_providers::{
8
- CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
9
- };
10
- use ort::session::Session;
7
+ use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
8
+ use ort::session::{OutputSelector, RunOptions, Session};
9
+ use std::cell::RefCell;
10
+ use std::collections::hash_map::Entry;
11
+ use std::collections::HashMap;
11
12
  use std::path::{Path, PathBuf};
12
13
  use std::sync::atomic::{AtomicUsize, Ordering};
13
- use std::sync::{Condvar, Mutex};
14
-
15
- pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
16
- let opt_level = match config.optimization_level {
17
- 0 => ort::session::builder::GraphOptimizationLevel::Disable,
18
- 1 => ort::session::builder::GraphOptimizationLevel::Level1,
19
- 2 => ort::session::builder::GraphOptimizationLevel::Level2,
20
- _ => ort::session::builder::GraphOptimizationLevel::Level3,
21
- };
22
-
23
- fn ort_err(e: impl std::fmt::Display) -> GteError {
24
- GteError::Ort(e.to_string())
25
- }
26
-
27
- let mut builder = Session::builder()
28
- .map_err(ort_err)?
29
- .with_optimization_level(opt_level)
30
- .map_err(ort_err)?;
31
-
32
- let providers = preferred_execution_providers(config.execution_providers.as_deref());
33
- if !providers.is_empty() {
34
- builder = builder
35
- .with_execution_providers(providers)
36
- .map_err(ort_err)?;
37
- }
38
-
39
- builder.commit_from_file(model_path).map_err(ort_err)
40
- }
41
14
 
42
15
  // ---------------------------------------------------------------------------
43
- // Session pool
16
+ // Thread-local session storage — each OS thread lazily creates its own ONNX
17
+ // session the first time it calls into a given pool. No Mutex, no contention.
44
18
  // ---------------------------------------------------------------------------
45
19
 
46
- fn pool_capacity() -> usize {
47
- let available = std::thread::available_parallelism()
48
- .map(|n| n.get())
49
- .unwrap_or(1);
50
- parse_pool_capacity_override().map_or(available, |cap| cap.min(available).max(1))
20
+ static NEXT_POOL_ID: AtomicUsize = AtomicUsize::new(1);
21
+
22
+ struct SessionRecipe {
23
+ model_path: PathBuf,
24
+ build_config: ModelConfig,
51
25
  }
52
26
 
53
- fn parse_pool_capacity_override() -> Option<usize> {
54
- let raw = std::env::var("GTE_SESSION_POOL_CAP").ok()?;
55
- let parsed = raw.trim().parse::<usize>().ok()?;
56
- (parsed > 0).then_some(parsed)
27
+ thread_local! {
28
+ static SESSIONS: RefCell<HashMap<usize, Session>> = RefCell::new(HashMap::new());
57
29
  }
58
30
 
59
31
  pub struct SessionPool {
60
- sessions: Mutex<Vec<Session>>,
61
- available: Condvar,
62
- created: AtomicUsize,
63
- capacity: usize,
64
- model_path: PathBuf,
65
- build_config: ModelConfig,
32
+ pool_id: usize,
33
+ recipe: SessionRecipe,
66
34
  }
67
35
 
68
36
  impl SessionPool {
69
- pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
70
- let capacity = pool_capacity();
71
- Self {
72
- sessions: Mutex::new(vec![initial]),
73
- available: Condvar::new(),
74
- created: AtomicUsize::new(1),
75
- capacity,
76
- model_path,
77
- build_config,
78
- }
79
- }
80
-
81
- pub fn acquire(&self) -> Result<PooledSession<'_>> {
82
- if let Some(session) = self.take_available_session() {
83
- return Ok(PooledSession {
84
- pool: self,
85
- session: Some(session),
86
- });
87
- }
37
+ pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
38
+ let pool_id = NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed);
88
39
 
89
- if let Some(session) = self.try_grow()? {
90
- return Ok(PooledSession {
91
- pool: self,
92
- session: Some(session),
93
- });
94
- }
40
+ SESSIONS.with(|map| {
41
+ _ = map.borrow_mut().insert(pool_id, initial);
42
+ });
95
43
 
96
- let session = self.wait_for_session();
97
- Ok(PooledSession {
98
- pool: self,
99
- session: Some(session),
44
+ Ok(Self {
45
+ pool_id,
46
+ recipe: SessionRecipe { model_path: model_path.to_path_buf(), build_config: build_config.clone() },
100
47
  })
101
48
  }
102
49
 
103
- fn release(&self, session: Session) {
104
- self.sessions.lock().unwrap().push(session);
105
- self.available.notify_one();
50
+ pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
51
+ self.with_session(|session| run_session(session, tokenized, config))
106
52
  }
107
53
 
108
- fn take_available_session(&self) -> Option<Session> {
109
- self.sessions.lock().unwrap().pop()
54
+ pub fn with_session<F, R>(&self, f: F) -> Result<R>
55
+ where
56
+ F: FnOnce(&mut Session) -> Result<R>,
57
+ {
58
+ SESSIONS.with(|map| {
59
+ let mut map = map.borrow_mut();
60
+ let session = match map.entry(self.pool_id) {
61
+ Entry::Occupied(e) => e.into_mut(),
62
+ Entry::Vacant(e) => {
63
+ let session = build_session(&self.recipe.model_path, &self.recipe.build_config)?;
64
+ e.insert(session)
65
+ }
66
+ };
67
+ f(session)
68
+ })
110
69
  }
70
+ }
111
71
 
112
- fn try_grow(&self) -> Result<Option<Session>> {
113
- let grew = self
114
- .created
115
- .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
116
- (count < self.capacity).then_some(count + 1)
117
- });
118
- if grew.is_err() {
119
- return Ok(None);
120
- }
72
+ // ---------------------------------------------------------------------------
73
+ // Session construction
74
+ // ---------------------------------------------------------------------------
121
75
 
122
- match build_session(&self.model_path, &self.build_config) {
123
- Ok(session) => Ok(Some(session)),
124
- Err(error) => {
125
- self.created.fetch_sub(1, Ordering::AcqRel);
126
- Err(error)
127
- }
128
- }
76
+ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
77
+ fn ort_err(e: impl std::fmt::Display) -> GteError {
78
+ GteError::Ort(e.to_string())
129
79
  }
130
80
 
131
- fn wait_for_session(&self) -> Session {
132
- let mut lock = self.sessions.lock().unwrap();
133
- loop {
134
- if let Some(session) = lock.pop() {
135
- return session;
136
- }
137
- lock = self.available.wait(lock).unwrap();
138
- }
139
- }
140
- }
81
+ let opt_level = match config.optimization_level {
82
+ 0 => ort::session::builder::GraphOptimizationLevel::Disable,
83
+ 1 => ort::session::builder::GraphOptimizationLevel::Level1,
84
+ 2 => ort::session::builder::GraphOptimizationLevel::Level2,
85
+ _ => ort::session::builder::GraphOptimizationLevel::Level3,
86
+ };
141
87
 
142
- pub struct PooledSession<'a> {
143
- pool: &'a SessionPool,
144
- session: Option<Session>,
145
- }
88
+ let mut builder = Session::builder().map_err(ort_err)?.with_optimization_level(opt_level).map_err(ort_err)?;
146
89
 
147
- impl std::ops::Deref for PooledSession<'_> {
148
- type Target = Session;
149
- fn deref(&self) -> &Session {
150
- self.session.as_ref().unwrap()
151
- }
152
- }
90
+ let intra_threads = std::env::var("GTE_INTRA_OP_NUM_THREADS")
91
+ .ok()
92
+ .and_then(|v| v.trim().parse::<usize>().ok())
93
+ .unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get().min(4)).unwrap_or(1));
94
+ builder = builder.with_intra_threads(intra_threads).map_err(ort_err)?;
153
95
 
154
- impl std::ops::DerefMut for PooledSession<'_> {
155
- fn deref_mut(&mut self) -> &mut Session {
156
- self.session.as_mut().unwrap()
157
- }
158
- }
96
+ let inter_threads =
97
+ std::env::var("GTE_INTER_OP_NUM_THREADS").ok().and_then(|v| v.trim().parse::<usize>().ok()).unwrap_or(1);
98
+ builder = builder.with_inter_threads(inter_threads).map_err(ort_err)?;
159
99
 
160
- impl Drop for PooledSession<'_> {
161
- fn drop(&mut self) {
162
- if let Some(s) = self.session.take() {
163
- self.pool.release(s);
164
- }
100
+ let providers = match config.execution_providers.as_deref() {
101
+ Some(override_val) => preferred_execution_providers(Some(override_val)),
102
+ None => auto_detect_providers(),
103
+ };
104
+ if !providers.is_empty() {
105
+ builder = builder.with_execution_providers(providers).map_err(ort_err)?;
165
106
  }
166
- }
167
-
168
- // ---------------------------------------------------------------------------
169
107
 
170
- fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
171
- let order = resolve_provider_order(order_override);
108
+ builder.commit_from_file(model_path).map_err(ort_err)
109
+ }
172
110
 
111
+ fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
173
112
  let mut providers = Vec::new();
174
- for provider in parse_provider_registrations(order.as_str()) {
175
- match provider {
176
- "xnnpack" => {
177
- providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
178
- }
179
- "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
180
- _ => {}
181
- }
182
- }
113
+ #[cfg(target_arch = "aarch64")]
114
+ providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
183
115
  providers
184
116
  }
185
117
 
186
- fn resolve_provider_order(order_override: Option<&str>) -> String {
187
- let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
188
- resolve_provider_order_with_env(order_override, env_order.as_deref())
189
- }
190
-
191
- fn resolve_provider_order_with_env(
192
- order_override: Option<&str>,
193
- env_order: Option<&str>,
194
- ) -> String {
195
- order_override
196
- .or(env_order)
197
- .unwrap_or("cpu")
198
- .to_ascii_lowercase()
199
- }
118
+ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
119
+ let order = match order_override {
120
+ Some(s) => s.to_ascii_lowercase(),
121
+ None => return auto_detect_providers(),
122
+ };
200
123
 
201
- fn parse_provider_registrations(order: &str) -> Vec<&str> {
202
- let mut providers = Vec::new();
203
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
204
- match provider {
205
- "xnnpack" | "coreml" => providers.push(provider),
206
- "none" | "cpu" => {}
207
- _ => {}
208
- }
124
+ if order.is_empty() || order == "cpu" || order == "none" {
125
+ return Vec::new();
209
126
  }
127
+
128
+ let providers: Vec<_> = order
129
+ .split(',')
130
+ .map(str::trim)
131
+ .filter(|p| !p.is_empty())
132
+ .filter_map(|provider| match provider {
133
+ "xnnpack" => Some(XNNPACKExecutionProvider::default().build().fail_silently()),
134
+ "coreml" => Some(CoreMLExecutionProvider::default().build().fail_silently()),
135
+ _ => None,
136
+ })
137
+ .collect();
210
138
  providers
211
139
  }
212
140
 
213
- pub fn run_session(
214
- session: &mut Session,
215
- tokenized: &Tokenized,
216
- config: &ModelConfig,
217
- ) -> Result<Array2<f32>> {
141
+ // ---------------------------------------------------------------------------
142
+ // Run a single inference
143
+ // ---------------------------------------------------------------------------
144
+
145
+ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
218
146
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
219
- let outputs = session
220
- .run(input_tensors.inputs)
221
- .map_err(|e| GteError::Ort(e.to_string()))?;
147
+ let run_opts = RunOptions::new()
148
+ .map_err(|e| GteError::Ort(e.to_string()))?
149
+ .with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
150
+ let outputs =
151
+ session.run_with_options(input_tensors.inputs, &run_opts).map_err(|e| GteError::Ort(e.to_string()))?;
222
152
  let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
223
153
 
224
154
  extract_embeddings(array, input_tensors.attention_mask, config)
@@ -234,26 +164,21 @@ fn extract_embeddings(
234
164
  let shape = array.shape();
235
165
  if shape.len() != 3 || idx >= shape[1] {
236
166
  return Err(GteError::Inference(format!(
237
- "token extraction index {} out of bounds for output shape {:?}",
238
- idx, shape
167
+ "token extraction index {idx} out of bounds for output shape {shape:?}"
239
168
  )));
240
169
  }
241
170
  Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
242
171
  }
243
172
  ExtractorMode::MeanPool => {
244
173
  let ndim = array.ndim();
245
- let hidden_states = array.into_dimensionality::<ndarray::Ix3>().map_err(|_| {
246
- GteError::Inference(format!(
247
- "mean pooling requires rank-3 output, got rank {}",
248
- ndim
249
- ))
250
- })?;
174
+ let hidden_states = array
175
+ .into_dimensionality::<ndarray::Ix3>()
176
+ .map_err(|_| GteError::Inference(format!("mean pooling requires rank-3 output, got rank {ndim}")))?;
251
177
  mean_pool(hidden_states, attention_mask)
252
178
  }
253
- ExtractorMode::Raw => array
254
- .into_dimensionality::<Ix2>()
255
- .map(|view| view.to_owned())
256
- .map_err(|e| GteError::Shape(e.to_string())),
179
+ ExtractorMode::Raw => {
180
+ array.into_dimensionality::<Ix2>().map(|view| view.to_owned()).map_err(|e| GteError::Shape(e.to_string()))
181
+ }
257
182
  }
258
183
  }
259
184
 
@@ -262,10 +187,22 @@ mod tests {
262
187
  use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
263
188
  use ndarray::{array, ArrayView2};
264
189
 
265
- use super::{
266
- extract_embeddings, parse_pool_capacity_override, parse_provider_registrations,
267
- resolve_provider_order_with_env,
268
- };
190
+ use super::extract_embeddings;
191
+
192
+ fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
193
+ order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
194
+ }
195
+
196
+ fn parse_provider_registrations(order: &str) -> Vec<&str> {
197
+ let mut providers = Vec::new();
198
+ for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
199
+ match provider {
200
+ "xnnpack" | "coreml" => providers.push(provider),
201
+ _ => {}
202
+ }
203
+ }
204
+ providers
205
+ }
269
206
 
270
207
  fn test_config(mode: ExtractorMode) -> ModelConfig {
271
208
  ModelConfig {
@@ -277,6 +214,8 @@ mod tests {
277
214
  with_attention_mask: true,
278
215
  optimization_level: 3,
279
216
  execution_providers: None,
217
+ lowercase_input: false,
218
+ max_input_chars: None,
280
219
  }
281
220
  }
282
221
 
@@ -306,93 +245,45 @@ mod tests {
306
245
 
307
246
  #[test]
308
247
  fn resolve_provider_order_prefers_override() {
309
- assert_eq!(
310
- resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
311
- "xnnpack"
312
- );
248
+ assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
313
249
  assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
314
250
  }
315
251
 
316
252
  #[test]
317
253
  fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
318
- assert_eq!(
319
- resolve_provider_order_with_env(None, Some("coreml")),
320
- "coreml"
321
- );
254
+ assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
322
255
  assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
323
256
  }
324
257
 
325
- #[test]
326
- fn parse_pool_capacity_override_uses_positive_integer_only() {
327
- unsafe {
328
- std::env::remove_var("GTE_SESSION_POOL_CAP");
329
- }
330
- assert_eq!(parse_pool_capacity_override(), None);
331
-
332
- unsafe {
333
- std::env::set_var("GTE_SESSION_POOL_CAP", "0");
334
- }
335
- assert_eq!(parse_pool_capacity_override(), None);
336
-
337
- unsafe {
338
- std::env::set_var("GTE_SESSION_POOL_CAP", "4");
339
- }
340
- assert_eq!(parse_pool_capacity_override(), Some(4));
341
-
342
- unsafe {
343
- std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
344
- }
345
- assert_eq!(parse_pool_capacity_override(), None);
346
-
347
- unsafe {
348
- std::env::remove_var("GTE_SESSION_POOL_CAP");
349
- }
350
- }
351
-
352
258
  #[test]
353
259
  fn extract_embeddings_raw_copies_only_final_matrix() {
354
260
  let output = array![[1.0f32, 2.0], [3.0, 4.0]];
355
- let extracted = extract_embeddings(
356
- output.view().into_dyn(),
357
- empty_attention_mask(),
358
- &test_config(ExtractorMode::Raw),
359
- )
360
- .unwrap();
261
+ let extracted =
262
+ extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Raw))
263
+ .unwrap();
361
264
 
362
265
  assert_eq!(extracted, output);
363
266
  }
364
267
 
365
268
  #[test]
366
269
  fn extract_embeddings_token_selects_without_copying_full_sequence() {
367
- let output = array![
368
- [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
369
- [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
370
- ];
270
+ let output = array![[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]];
371
271
  let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
372
- let extracted = extract_embeddings(
373
- output.view().into_dyn(),
374
- empty_attention_mask(),
375
- &test_config(ExtractorMode::Token(1)),
376
- )
377
- .unwrap();
272
+ let extracted =
273
+ extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Token(1)))
274
+ .unwrap();
378
275
 
379
276
  assert_eq!(extracted, expected);
380
277
  }
381
278
 
382
279
  #[test]
383
280
  fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
384
- let output = array![
385
- [[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
386
- [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
387
- ];
281
+ let output = array![[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]], [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]];
388
282
  let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
389
283
  let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
390
- let extracted = extract_embeddings(
391
- output.view().into_dyn(),
392
- attention_mask.view(),
393
- &test_config(ExtractorMode::MeanPool),
394
- )
395
- .unwrap();
284
+ let extracted =
285
+ extract_embeddings(output.view().into_dyn(), attention_mask.view(), &test_config(ExtractorMode::MeanPool))
286
+ .unwrap();
396
287
 
397
288
  assert_eq!(extracted, expected);
398
289
  }