gte 0.0.13 → 0.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,224 +4,151 @@ use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
6
  use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
7
- use ort::execution_providers::{
8
- CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
9
- };
7
+ use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
10
8
  use ort::session::{OutputSelector, RunOptions, Session};
9
+ use std::cell::RefCell;
10
+ use std::collections::hash_map::Entry;
11
+ use std::collections::HashMap;
11
12
  use std::path::{Path, PathBuf};
12
13
  use std::sync::atomic::{AtomicUsize, Ordering};
13
- use std::sync::{Condvar, Mutex};
14
-
15
- pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
16
- let opt_level = match config.optimization_level {
17
- 0 => ort::session::builder::GraphOptimizationLevel::Disable,
18
- 1 => ort::session::builder::GraphOptimizationLevel::Level1,
19
- 2 => ort::session::builder::GraphOptimizationLevel::Level2,
20
- _ => ort::session::builder::GraphOptimizationLevel::Level3,
21
- };
22
-
23
- fn ort_err(e: impl std::fmt::Display) -> GteError {
24
- GteError::Ort(e.to_string())
25
- }
26
-
27
- let mut builder = Session::builder()
28
- .map_err(ort_err)?
29
- .with_optimization_level(opt_level)
30
- .map_err(ort_err)?;
31
-
32
- let providers = preferred_execution_providers(config.execution_providers.as_deref());
33
- if !providers.is_empty() {
34
- builder = builder
35
- .with_execution_providers(providers)
36
- .map_err(ort_err)?;
37
- }
38
-
39
- builder.commit_from_file(model_path).map_err(ort_err)
40
- }
41
14
 
42
15
  // ---------------------------------------------------------------------------
43
- // Session pool
16
+ // Thread-local session storage — each OS thread lazily creates its own ONNX
17
+ // session the first time it calls into a given pool. No Mutex, no contention.
44
18
  // ---------------------------------------------------------------------------
45
19
 
46
- fn pool_capacity() -> usize {
47
- let available = std::thread::available_parallelism()
48
- .map(|n| n.get())
49
- .unwrap_or(1);
50
- parse_pool_capacity_override().map_or(available, |cap| cap.min(available).max(1))
20
+ static NEXT_POOL_ID: AtomicUsize = AtomicUsize::new(1);
21
+
22
+ struct SessionRecipe {
23
+ model_path: PathBuf,
24
+ build_config: ModelConfig,
51
25
  }
52
26
 
53
- fn parse_pool_capacity_override() -> Option<usize> {
54
- let raw = std::env::var("GTE_SESSION_POOL_CAP").ok()?;
55
- let parsed = raw.trim().parse::<usize>().ok()?;
56
- (parsed > 0).then_some(parsed)
27
+ thread_local! {
28
+ static SESSIONS: RefCell<HashMap<usize, Session>> = RefCell::new(HashMap::new());
57
29
  }
58
30
 
59
31
  pub struct SessionPool {
60
- sessions: Mutex<Vec<Session>>,
61
- available: Condvar,
62
- created: AtomicUsize,
63
- capacity: usize,
64
- model_path: PathBuf,
65
- build_config: ModelConfig,
32
+ pool_id: usize,
33
+ recipe: SessionRecipe,
66
34
  }
67
35
 
68
36
  impl SessionPool {
69
- pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
70
- let capacity = pool_capacity();
71
- Self {
72
- sessions: Mutex::new(vec![initial]),
73
- available: Condvar::new(),
74
- created: AtomicUsize::new(1),
75
- capacity,
76
- model_path,
77
- build_config,
78
- }
79
- }
80
-
81
- pub fn acquire(&self) -> Result<PooledSession<'_>> {
82
- if let Some(session) = self.take_available_session() {
83
- return Ok(PooledSession {
84
- pool: self,
85
- session: Some(session),
86
- });
87
- }
37
+ pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
38
+ let pool_id = NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed);
88
39
 
89
- if let Some(session) = self.try_grow()? {
90
- return Ok(PooledSession {
91
- pool: self,
92
- session: Some(session),
93
- });
94
- }
40
+ SESSIONS.with(|map| {
41
+ _ = map.borrow_mut().insert(pool_id, initial);
42
+ });
95
43
 
96
- let session = self.wait_for_session();
97
- Ok(PooledSession {
98
- pool: self,
99
- session: Some(session),
44
+ Ok(Self {
45
+ pool_id,
46
+ recipe: SessionRecipe { model_path: model_path.to_path_buf(), build_config: build_config.clone() },
100
47
  })
101
48
  }
102
49
 
103
- fn release(&self, session: Session) {
104
- self.sessions.lock().unwrap().push(session);
105
- self.available.notify_one();
50
+ pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
51
+ self.with_session(|session| run_session(session, tokenized, config))
106
52
  }
107
53
 
108
- fn take_available_session(&self) -> Option<Session> {
109
- self.sessions.lock().unwrap().pop()
54
+ pub fn with_session<F, R>(&self, f: F) -> Result<R>
55
+ where
56
+ F: FnOnce(&mut Session) -> Result<R>,
57
+ {
58
+ SESSIONS.with(|map| {
59
+ let mut map = map.borrow_mut();
60
+ let session = match map.entry(self.pool_id) {
61
+ Entry::Occupied(e) => e.into_mut(),
62
+ Entry::Vacant(e) => {
63
+ let session = build_session(&self.recipe.model_path, &self.recipe.build_config)?;
64
+ e.insert(session)
65
+ }
66
+ };
67
+ f(session)
68
+ })
110
69
  }
70
+ }
111
71
 
112
- fn try_grow(&self) -> Result<Option<Session>> {
113
- let grew = self
114
- .created
115
- .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
116
- (count < self.capacity).then_some(count + 1)
117
- });
118
- if grew.is_err() {
119
- return Ok(None);
120
- }
72
+ // ---------------------------------------------------------------------------
73
+ // Session construction
74
+ // ---------------------------------------------------------------------------
121
75
 
122
- match build_session(&self.model_path, &self.build_config) {
123
- Ok(session) => Ok(Some(session)),
124
- Err(error) => {
125
- self.created.fetch_sub(1, Ordering::AcqRel);
126
- Err(error)
127
- }
128
- }
76
+ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
77
+ fn ort_err(e: impl std::fmt::Display) -> GteError {
78
+ GteError::Ort(e.to_string())
129
79
  }
130
80
 
131
- fn wait_for_session(&self) -> Session {
132
- let mut lock = self.sessions.lock().unwrap();
133
- loop {
134
- if let Some(session) = lock.pop() {
135
- return session;
136
- }
137
- lock = self.available.wait(lock).unwrap();
138
- }
139
- }
140
- }
81
+ let opt_level = match config.optimization_level {
82
+ 0 => ort::session::builder::GraphOptimizationLevel::Disable,
83
+ 1 => ort::session::builder::GraphOptimizationLevel::Level1,
84
+ 2 => ort::session::builder::GraphOptimizationLevel::Level2,
85
+ _ => ort::session::builder::GraphOptimizationLevel::Level3,
86
+ };
141
87
 
142
- pub struct PooledSession<'a> {
143
- pool: &'a SessionPool,
144
- session: Option<Session>,
145
- }
88
+ let mut builder = Session::builder().map_err(ort_err)?.with_optimization_level(opt_level).map_err(ort_err)?;
146
89
 
147
- impl std::ops::Deref for PooledSession<'_> {
148
- type Target = Session;
149
- fn deref(&self) -> &Session {
150
- self.session.as_ref().unwrap()
151
- }
152
- }
90
+ let intra_threads = std::env::var("GTE_INTRA_OP_NUM_THREADS")
91
+ .ok()
92
+ .and_then(|v| v.trim().parse::<usize>().ok())
93
+ .unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get().min(4)).unwrap_or(1));
94
+ builder = builder.with_intra_threads(intra_threads).map_err(ort_err)?;
153
95
 
154
- impl std::ops::DerefMut for PooledSession<'_> {
155
- fn deref_mut(&mut self) -> &mut Session {
156
- self.session.as_mut().unwrap()
157
- }
158
- }
96
+ let inter_threads =
97
+ std::env::var("GTE_INTER_OP_NUM_THREADS").ok().and_then(|v| v.trim().parse::<usize>().ok()).unwrap_or(1);
98
+ builder = builder.with_inter_threads(inter_threads).map_err(ort_err)?;
159
99
 
160
- impl Drop for PooledSession<'_> {
161
- fn drop(&mut self) {
162
- if let Some(s) = self.session.take() {
163
- self.pool.release(s);
164
- }
100
+ let providers = match config.execution_providers.as_deref() {
101
+ Some(override_val) => preferred_execution_providers(Some(override_val)),
102
+ None => auto_detect_providers(),
103
+ };
104
+ if !providers.is_empty() {
105
+ builder = builder.with_execution_providers(providers).map_err(ort_err)?;
165
106
  }
166
- }
167
-
168
- // ---------------------------------------------------------------------------
169
107
 
170
- fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
171
- let order = resolve_provider_order(order_override);
108
+ builder.commit_from_file(model_path).map_err(ort_err)
109
+ }
172
110
 
111
+ fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
173
112
  let mut providers = Vec::new();
174
- for provider in parse_provider_registrations(order.as_str()) {
175
- match provider {
176
- "xnnpack" => {
177
- providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
178
- }
179
- "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
180
- _ => {}
181
- }
182
- }
113
+ #[cfg(target_arch = "aarch64")]
114
+ providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
183
115
  providers
184
116
  }
185
117
 
186
- fn resolve_provider_order(order_override: Option<&str>) -> String {
187
- let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
188
- resolve_provider_order_with_env(order_override, env_order.as_deref())
189
- }
190
-
191
- fn resolve_provider_order_with_env(
192
- order_override: Option<&str>,
193
- env_order: Option<&str>,
194
- ) -> String {
195
- order_override
196
- .or(env_order)
197
- .unwrap_or("cpu")
198
- .to_ascii_lowercase()
199
- }
118
+ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
119
+ let order = match order_override {
120
+ Some(s) => s.to_ascii_lowercase(),
121
+ None => return auto_detect_providers(),
122
+ };
200
123
 
201
- fn parse_provider_registrations(order: &str) -> Vec<&str> {
202
- let mut providers = Vec::new();
203
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
204
- match provider {
205
- "xnnpack" | "coreml" => providers.push(provider),
206
- "none" | "cpu" => {}
207
- _ => {}
208
- }
124
+ if order.is_empty() || order == "cpu" || order == "none" {
125
+ return Vec::new();
209
126
  }
127
+
128
+ let providers: Vec<_> = order
129
+ .split(',')
130
+ .map(str::trim)
131
+ .filter(|p| !p.is_empty())
132
+ .filter_map(|provider| match provider {
133
+ "xnnpack" => Some(XNNPACKExecutionProvider::default().build().fail_silently()),
134
+ "coreml" => Some(CoreMLExecutionProvider::default().build().fail_silently()),
135
+ _ => None,
136
+ })
137
+ .collect();
210
138
  providers
211
139
  }
212
140
 
213
- pub fn run_session(
214
- session: &mut Session,
215
- tokenized: &Tokenized,
216
- config: &ModelConfig,
217
- ) -> Result<Array2<f32>> {
141
+ // ---------------------------------------------------------------------------
142
+ // Run a single inference
143
+ // ---------------------------------------------------------------------------
144
+
145
+ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
218
146
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
219
147
  let run_opts = RunOptions::new()
220
148
  .map_err(|e| GteError::Ort(e.to_string()))?
221
149
  .with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
222
- let outputs = session
223
- .run_with_options(input_tensors.inputs, &run_opts)
224
- .map_err(|e| GteError::Ort(e.to_string()))?;
150
+ let outputs =
151
+ session.run_with_options(input_tensors.inputs, &run_opts).map_err(|e| GteError::Ort(e.to_string()))?;
225
152
  let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
226
153
 
227
154
  extract_embeddings(array, input_tensors.attention_mask, config)
@@ -237,26 +164,21 @@ fn extract_embeddings(
237
164
  let shape = array.shape();
238
165
  if shape.len() != 3 || idx >= shape[1] {
239
166
  return Err(GteError::Inference(format!(
240
- "token extraction index {} out of bounds for output shape {:?}",
241
- idx, shape
167
+ "token extraction index {idx} out of bounds for output shape {shape:?}"
242
168
  )));
243
169
  }
244
170
  Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
245
171
  }
246
172
  ExtractorMode::MeanPool => {
247
173
  let ndim = array.ndim();
248
- let hidden_states = array.into_dimensionality::<ndarray::Ix3>().map_err(|_| {
249
- GteError::Inference(format!(
250
- "mean pooling requires rank-3 output, got rank {}",
251
- ndim
252
- ))
253
- })?;
174
+ let hidden_states = array
175
+ .into_dimensionality::<ndarray::Ix3>()
176
+ .map_err(|_| GteError::Inference(format!("mean pooling requires rank-3 output, got rank {ndim}")))?;
254
177
  mean_pool(hidden_states, attention_mask)
255
178
  }
256
- ExtractorMode::Raw => array
257
- .into_dimensionality::<Ix2>()
258
- .map(|view| view.to_owned())
259
- .map_err(|e| GteError::Shape(e.to_string())),
179
+ ExtractorMode::Raw => {
180
+ array.into_dimensionality::<Ix2>().map(|view| view.to_owned()).map_err(|e| GteError::Shape(e.to_string()))
181
+ }
260
182
  }
261
183
  }
262
184
 
@@ -265,10 +187,22 @@ mod tests {
265
187
  use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
266
188
  use ndarray::{array, ArrayView2};
267
189
 
268
- use super::{
269
- extract_embeddings, parse_pool_capacity_override, parse_provider_registrations,
270
- resolve_provider_order_with_env,
271
- };
190
+ use super::extract_embeddings;
191
+
192
+ fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
193
+ order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
194
+ }
195
+
196
+ fn parse_provider_registrations(order: &str) -> Vec<&str> {
197
+ let mut providers = Vec::new();
198
+ for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
199
+ match provider {
200
+ "xnnpack" | "coreml" => providers.push(provider),
201
+ _ => {}
202
+ }
203
+ }
204
+ providers
205
+ }
272
206
 
273
207
  fn test_config(mode: ExtractorMode) -> ModelConfig {
274
208
  ModelConfig {
@@ -280,6 +214,8 @@ mod tests {
280
214
  with_attention_mask: true,
281
215
  optimization_level: 3,
282
216
  execution_providers: None,
217
+ lowercase_input: false,
218
+ max_input_chars: None,
283
219
  }
284
220
  }
285
221
 
@@ -309,93 +245,45 @@ mod tests {
309
245
 
310
246
  #[test]
311
247
  fn resolve_provider_order_prefers_override() {
312
- assert_eq!(
313
- resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
314
- "xnnpack"
315
- );
248
+ assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
316
249
  assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
317
250
  }
318
251
 
319
252
  #[test]
320
253
  fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
321
- assert_eq!(
322
- resolve_provider_order_with_env(None, Some("coreml")),
323
- "coreml"
324
- );
254
+ assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
325
255
  assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
326
256
  }
327
257
 
328
- #[test]
329
- fn parse_pool_capacity_override_uses_positive_integer_only() {
330
- unsafe {
331
- std::env::remove_var("GTE_SESSION_POOL_CAP");
332
- }
333
- assert_eq!(parse_pool_capacity_override(), None);
334
-
335
- unsafe {
336
- std::env::set_var("GTE_SESSION_POOL_CAP", "0");
337
- }
338
- assert_eq!(parse_pool_capacity_override(), None);
339
-
340
- unsafe {
341
- std::env::set_var("GTE_SESSION_POOL_CAP", "4");
342
- }
343
- assert_eq!(parse_pool_capacity_override(), Some(4));
344
-
345
- unsafe {
346
- std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
347
- }
348
- assert_eq!(parse_pool_capacity_override(), None);
349
-
350
- unsafe {
351
- std::env::remove_var("GTE_SESSION_POOL_CAP");
352
- }
353
- }
354
-
355
258
  #[test]
356
259
  fn extract_embeddings_raw_copies_only_final_matrix() {
357
260
  let output = array![[1.0f32, 2.0], [3.0, 4.0]];
358
- let extracted = extract_embeddings(
359
- output.view().into_dyn(),
360
- empty_attention_mask(),
361
- &test_config(ExtractorMode::Raw),
362
- )
363
- .unwrap();
261
+ let extracted =
262
+ extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Raw))
263
+ .unwrap();
364
264
 
365
265
  assert_eq!(extracted, output);
366
266
  }
367
267
 
368
268
  #[test]
369
269
  fn extract_embeddings_token_selects_without_copying_full_sequence() {
370
- let output = array![
371
- [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
372
- [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
373
- ];
270
+ let output = array![[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]];
374
271
  let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
375
- let extracted = extract_embeddings(
376
- output.view().into_dyn(),
377
- empty_attention_mask(),
378
- &test_config(ExtractorMode::Token(1)),
379
- )
380
- .unwrap();
272
+ let extracted =
273
+ extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Token(1)))
274
+ .unwrap();
381
275
 
382
276
  assert_eq!(extracted, expected);
383
277
  }
384
278
 
385
279
  #[test]
386
280
  fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
387
- let output = array![
388
- [[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
389
- [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
390
- ];
281
+ let output = array![[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]], [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]];
391
282
  let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
392
283
  let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
393
- let extracted = extract_embeddings(
394
- output.view().into_dyn(),
395
- attention_mask.view(),
396
- &test_config(ExtractorMode::MeanPool),
397
- )
398
- .unwrap();
284
+ let extracted =
285
+ extract_embeddings(output.view().into_dyn(), attention_mask.view(), &test_config(ExtractorMode::MeanPool))
286
+ .unwrap();
399
287
 
400
288
  assert_eq!(extracted, expected);
401
289
  }