gte 0.0.13-aarch64-linux → 0.0.15-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,224 +4,206 @@ use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
6
  use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
7
- use ort::execution_providers::{
8
- CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
9
- };
7
+ use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
10
8
  use ort::session::{OutputSelector, RunOptions, Session};
9
+ use parking_lot::Mutex;
11
10
  use std::path::{Path, PathBuf};
12
11
  use std::sync::atomic::{AtomicUsize, Ordering};
13
- use std::sync::{Condvar, Mutex};
14
-
15
- pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
16
- let opt_level = match config.optimization_level {
17
- 0 => ort::session::builder::GraphOptimizationLevel::Disable,
18
- 1 => ort::session::builder::GraphOptimizationLevel::Level1,
19
- 2 => ort::session::builder::GraphOptimizationLevel::Level2,
20
- _ => ort::session::builder::GraphOptimizationLevel::Level3,
21
- };
22
-
23
- fn ort_err(e: impl std::fmt::Display) -> GteError {
24
- GteError::Ort(e.to_string())
25
- }
26
-
27
- let mut builder = Session::builder()
28
- .map_err(ort_err)?
29
- .with_optimization_level(opt_level)
30
- .map_err(ort_err)?;
31
-
32
- let providers = preferred_execution_providers(config.execution_providers.as_deref());
33
- if !providers.is_empty() {
34
- builder = builder
35
- .with_execution_providers(providers)
36
- .map_err(ort_err)?;
37
- }
38
-
39
- builder.commit_from_file(model_path).map_err(ort_err)
40
- }
12
+ use std::sync::Arc;
41
13
 
42
14
  // ---------------------------------------------------------------------------
43
- // Session pool
15
+ // Lazy session pool — starts with 1 session, grows on contention, capped.
16
+ //
17
+ // Pool max is resolved in order:
18
+ // 1. GTE_SESSION_POOL_SIZE env var (explicit override)
19
+ // 2. Auto: 2 (conservative: 2× pure Ruby memory at peak, no OOM risk)
20
+ //
21
+ // At idle the pool holds 1 session (same memory as pure Ruby's single
22
+ // OnnxRuntime::Model). When all existing sessions are busy and the cap
23
+ // hasn't been reached, a new session is created on-demand.
44
24
  // ---------------------------------------------------------------------------
45
25
 
46
- fn pool_capacity() -> usize {
47
- let available = std::thread::available_parallelism()
48
- .map(|n| n.get())
49
- .unwrap_or(1);
50
- parse_pool_capacity_override().map_or(available, |cap| cap.min(available).max(1))
26
+ fn resolve_pool_cap() -> usize {
27
+ if let Some(n) =
28
+ std::env::var("GTE_SESSION_POOL_SIZE").ok().and_then(|v| v.trim().parse::<usize>().ok()).filter(|&n| n > 0)
29
+ {
30
+ return n;
31
+ }
32
+ 2
51
33
  }
52
34
 
53
- fn parse_pool_capacity_override() -> Option<usize> {
54
- let raw = std::env::var("GTE_SESSION_POOL_CAP").ok()?;
55
- let parsed = raw.trim().parse::<usize>().ok()?;
56
- (parsed > 0).then_some(parsed)
35
+ pub struct SessionPool {
36
+ inner: Mutex<PoolInner>,
37
+ next_idx: AtomicUsize,
38
+ cap: usize,
57
39
  }
58
40
 
59
- pub struct SessionPool {
60
- sessions: Mutex<Vec<Session>>,
61
- available: Condvar,
62
- created: AtomicUsize,
63
- capacity: usize,
41
+ struct PoolInner {
42
+ sessions: Vec<Arc<Mutex<Session>>>,
64
43
  model_path: PathBuf,
65
44
  build_config: ModelConfig,
66
45
  }
67
46
 
68
47
  impl SessionPool {
69
- pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
70
- let capacity = pool_capacity();
71
- Self {
72
- sessions: Mutex::new(vec![initial]),
73
- available: Condvar::new(),
74
- created: AtomicUsize::new(1),
75
- capacity,
76
- model_path,
77
- build_config,
78
- }
79
- }
80
-
81
- pub fn acquire(&self) -> Result<PooledSession<'_>> {
82
- if let Some(session) = self.take_available_session() {
83
- return Ok(PooledSession {
84
- pool: self,
85
- session: Some(session),
86
- });
87
- }
88
-
89
- if let Some(session) = self.try_grow()? {
90
- return Ok(PooledSession {
91
- pool: self,
92
- session: Some(session),
93
- });
94
- }
95
-
96
- let session = self.wait_for_session();
97
- Ok(PooledSession {
98
- pool: self,
99
- session: Some(session),
48
+ pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
49
+ let cap = resolve_pool_cap();
50
+ let sessions = vec![Arc::new(Mutex::new(initial))];
51
+
52
+ Ok(Self {
53
+ inner: Mutex::new(PoolInner {
54
+ sessions,
55
+ model_path: model_path.to_path_buf(),
56
+ build_config: build_config.clone(),
57
+ }),
58
+ next_idx: AtomicUsize::new(0),
59
+ cap,
100
60
  })
101
61
  }
102
62
 
103
- fn release(&self, session: Session) {
104
- self.sessions.lock().unwrap().push(session);
105
- self.available.notify_one();
63
+ pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
64
+ self.with_session(|session| run_session(session, tokenized, config))
106
65
  }
107
66
 
108
- fn take_available_session(&self) -> Option<Session> {
109
- self.sessions.lock().unwrap().pop()
110
- }
67
+ pub fn with_session<F, R>(&self, f: F) -> Result<R>
68
+ where
69
+ F: FnOnce(&mut Session) -> Result<R>,
70
+ {
71
+ const SPIN_LIMIT: u32 = 64;
111
72
 
112
- fn try_grow(&self) -> Result<Option<Session>> {
113
- let grew = self
114
- .created
115
- .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
116
- (count < self.capacity).then_some(count + 1)
117
- });
118
- if grew.is_err() {
119
- return Ok(None);
120
- }
73
+ loop {
74
+ // Snapshot the pool under the outer lock so the scan below
75
+ // doesn't contend on that lock at all.
76
+ let arcs: Vec<Arc<Mutex<Session>>> = {
77
+ let inner = self.inner.lock();
78
+ inner.sessions.clone()
79
+ };
80
+ let len = arcs.len();
81
+ let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
82
+
83
+ for offset in 0..len {
84
+ let idx = (start + offset) % len;
85
+ if let Some(mut guard) = arcs[idx].try_lock() {
86
+ return f(&mut guard);
87
+ }
88
+ }
121
89
 
122
- match build_session(&self.model_path, &self.build_config) {
123
- Ok(session) => Ok(Some(session)),
124
- Err(error) => {
125
- self.created.fetch_sub(1, Ordering::AcqRel);
126
- Err(error)
90
+ // All sessions busy — try to grow the pool
91
+ let grew = {
92
+ let mut inner = self.inner.lock();
93
+ if inner.sessions.len() < self.cap {
94
+ match build_session(&inner.model_path, &inner.build_config) {
95
+ Ok(session) => {
96
+ inner.sessions.push(Arc::new(Mutex::new(session)));
97
+ true
98
+ }
99
+ Err(e) => return Err(e),
100
+ }
101
+ } else {
102
+ false
103
+ }
104
+ };
105
+
106
+ if grew {
107
+ continue;
127
108
  }
128
- }
129
- }
130
109
 
131
- fn wait_for_session(&self) -> Session {
132
- let mut lock = self.sessions.lock().unwrap();
133
- loop {
134
- if let Some(session) = lock.pop() {
135
- return session;
110
+ // At cap spin briefly, then block on a session
111
+ let idx = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
112
+ let arc = Arc::clone(&arcs[idx]);
113
+
114
+ for _ in 0..SPIN_LIMIT {
115
+ if let Some(mut guard) = arc.try_lock() {
116
+ return f(&mut guard);
117
+ }
118
+ std::hint::spin_loop();
136
119
  }
137
- lock = self.available.wait(lock).unwrap();
120
+
121
+ let mut guard = arc.lock();
122
+ return f(&mut guard);
138
123
  }
139
124
  }
140
125
  }
141
126
 
142
- pub struct PooledSession<'a> {
143
- pool: &'a SessionPool,
144
- session: Option<Session>,
145
- }
127
+ // ---------------------------------------------------------------------------
128
+ // Session construction
129
+ // ---------------------------------------------------------------------------
146
130
 
147
- impl std::ops::Deref for PooledSession<'_> {
148
- type Target = Session;
149
- fn deref(&self) -> &Session {
150
- self.session.as_ref().unwrap()
131
+ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
132
+ fn ort_err(e: impl std::fmt::Display) -> GteError {
133
+ GteError::Ort(e.to_string())
151
134
  }
152
- }
153
135
 
154
- impl std::ops::DerefMut for PooledSession<'_> {
155
- fn deref_mut(&mut self) -> &mut Session {
156
- self.session.as_mut().unwrap()
157
- }
158
- }
136
+ let opt_level = match config.optimization_level {
137
+ 0 => ort::session::builder::GraphOptimizationLevel::Disable,
138
+ 1 => ort::session::builder::GraphOptimizationLevel::Level1,
139
+ 2 => ort::session::builder::GraphOptimizationLevel::Level2,
140
+ _ => ort::session::builder::GraphOptimizationLevel::Level3,
141
+ };
159
142
 
160
- impl Drop for PooledSession<'_> {
161
- fn drop(&mut self) {
162
- if let Some(s) = self.session.take() {
163
- self.pool.release(s);
164
- }
165
- }
166
- }
143
+ let mut builder = Session::builder().map_err(ort_err)?.with_optimization_level(opt_level).map_err(ort_err)?;
167
144
 
168
- // ---------------------------------------------------------------------------
145
+ let intra_threads = std::env::var("GTE_INTRA_OP_NUM_THREADS")
146
+ .ok()
147
+ .and_then(|v| v.trim().parse::<usize>().ok())
148
+ .unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get().min(4)).unwrap_or(1));
149
+ builder = builder.with_intra_threads(intra_threads).map_err(ort_err)?;
169
150
 
170
- fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
171
- let order = resolve_provider_order(order_override);
151
+ let inter_threads =
152
+ std::env::var("GTE_INTER_OP_NUM_THREADS").ok().and_then(|v| v.trim().parse::<usize>().ok()).unwrap_or(1);
153
+ builder = builder.with_inter_threads(inter_threads).map_err(ort_err)?;
172
154
 
173
- let mut providers = Vec::new();
174
- for provider in parse_provider_registrations(order.as_str()) {
175
- match provider {
176
- "xnnpack" => {
177
- providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
178
- }
179
- "coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
180
- _ => {}
181
- }
155
+ let providers = match config.execution_providers.as_deref() {
156
+ Some(override_val) => preferred_execution_providers(Some(override_val)),
157
+ None => auto_detect_providers(),
158
+ };
159
+ if !providers.is_empty() {
160
+ builder = builder.with_execution_providers(providers).map_err(ort_err)?;
182
161
  }
183
- providers
184
- }
185
162
 
186
- fn resolve_provider_order(order_override: Option<&str>) -> String {
187
- let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
188
- resolve_provider_order_with_env(order_override, env_order.as_deref())
163
+ builder.commit_from_file(model_path).map_err(ort_err)
189
164
  }
190
165
 
191
- fn resolve_provider_order_with_env(
192
- order_override: Option<&str>,
193
- env_order: Option<&str>,
194
- ) -> String {
195
- order_override
196
- .or(env_order)
197
- .unwrap_or("cpu")
198
- .to_ascii_lowercase()
166
+ fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
167
+ let mut providers = Vec::new();
168
+ #[cfg(target_arch = "aarch64")]
169
+ providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
170
+ providers
199
171
  }
200
172
 
201
- fn parse_provider_registrations(order: &str) -> Vec<&str> {
202
- let mut providers = Vec::new();
203
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
204
- match provider {
205
- "xnnpack" | "coreml" => providers.push(provider),
206
- "none" | "cpu" => {}
207
- _ => {}
208
- }
173
+ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
174
+ let order = match order_override {
175
+ Some(s) => s.to_ascii_lowercase(),
176
+ None => return auto_detect_providers(),
177
+ };
178
+
179
+ if order.is_empty() || order == "cpu" || order == "none" {
180
+ return Vec::new();
209
181
  }
182
+
183
+ let providers: Vec<_> = order
184
+ .split(',')
185
+ .map(str::trim)
186
+ .filter(|p| !p.is_empty())
187
+ .filter_map(|provider| match provider {
188
+ "xnnpack" => Some(XNNPACKExecutionProvider::default().build().fail_silently()),
189
+ "coreml" => Some(CoreMLExecutionProvider::default().build().fail_silently()),
190
+ _ => None,
191
+ })
192
+ .collect();
210
193
  providers
211
194
  }
212
195
 
213
- pub fn run_session(
214
- session: &mut Session,
215
- tokenized: &Tokenized,
216
- config: &ModelConfig,
217
- ) -> Result<Array2<f32>> {
196
+ // ---------------------------------------------------------------------------
197
+ // Run a single inference
198
+ // ---------------------------------------------------------------------------
199
+
200
+ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
218
201
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
219
202
  let run_opts = RunOptions::new()
220
203
  .map_err(|e| GteError::Ort(e.to_string()))?
221
204
  .with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
222
- let outputs = session
223
- .run_with_options(input_tensors.inputs, &run_opts)
224
- .map_err(|e| GteError::Ort(e.to_string()))?;
205
+ let outputs =
206
+ session.run_with_options(input_tensors.inputs, &run_opts).map_err(|e| GteError::Ort(e.to_string()))?;
225
207
  let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
226
208
 
227
209
  extract_embeddings(array, input_tensors.attention_mask, config)
@@ -237,26 +219,21 @@ fn extract_embeddings(
237
219
  let shape = array.shape();
238
220
  if shape.len() != 3 || idx >= shape[1] {
239
221
  return Err(GteError::Inference(format!(
240
- "token extraction index {} out of bounds for output shape {:?}",
241
- idx, shape
222
+ "token extraction index {idx} out of bounds for output shape {shape:?}"
242
223
  )));
243
224
  }
244
225
  Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
245
226
  }
246
227
  ExtractorMode::MeanPool => {
247
228
  let ndim = array.ndim();
248
- let hidden_states = array.into_dimensionality::<ndarray::Ix3>().map_err(|_| {
249
- GteError::Inference(format!(
250
- "mean pooling requires rank-3 output, got rank {}",
251
- ndim
252
- ))
253
- })?;
229
+ let hidden_states = array
230
+ .into_dimensionality::<ndarray::Ix3>()
231
+ .map_err(|_| GteError::Inference(format!("mean pooling requires rank-3 output, got rank {ndim}")))?;
254
232
  mean_pool(hidden_states, attention_mask)
255
233
  }
256
- ExtractorMode::Raw => array
257
- .into_dimensionality::<Ix2>()
258
- .map(|view| view.to_owned())
259
- .map_err(|e| GteError::Shape(e.to_string())),
234
+ ExtractorMode::Raw => {
235
+ array.into_dimensionality::<Ix2>().map(|view| view.to_owned()).map_err(|e| GteError::Shape(e.to_string()))
236
+ }
260
237
  }
261
238
  }
262
239
 
@@ -265,10 +242,22 @@ mod tests {
265
242
  use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
266
243
  use ndarray::{array, ArrayView2};
267
244
 
268
- use super::{
269
- extract_embeddings, parse_pool_capacity_override, parse_provider_registrations,
270
- resolve_provider_order_with_env,
271
- };
245
+ use super::extract_embeddings;
246
+
247
+ fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
248
+ order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
249
+ }
250
+
251
+ fn parse_provider_registrations(order: &str) -> Vec<&str> {
252
+ let mut providers = Vec::new();
253
+ for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
254
+ match provider {
255
+ "xnnpack" | "coreml" => providers.push(provider),
256
+ _ => {}
257
+ }
258
+ }
259
+ providers
260
+ }
272
261
 
273
262
  fn test_config(mode: ExtractorMode) -> ModelConfig {
274
263
  ModelConfig {
@@ -280,6 +269,8 @@ mod tests {
280
269
  with_attention_mask: true,
281
270
  optimization_level: 3,
282
271
  execution_providers: None,
272
+ lowercase_input: false,
273
+ max_input_chars: None,
283
274
  }
284
275
  }
285
276
 
@@ -309,93 +300,45 @@ mod tests {
309
300
 
310
301
  #[test]
311
302
  fn resolve_provider_order_prefers_override() {
312
- assert_eq!(
313
- resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
314
- "xnnpack"
315
- );
303
+ assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
316
304
  assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
317
305
  }
318
306
 
319
307
  #[test]
320
308
  fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
321
- assert_eq!(
322
- resolve_provider_order_with_env(None, Some("coreml")),
323
- "coreml"
324
- );
309
+ assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
325
310
  assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
326
311
  }
327
312
 
328
- #[test]
329
- fn parse_pool_capacity_override_uses_positive_integer_only() {
330
- unsafe {
331
- std::env::remove_var("GTE_SESSION_POOL_CAP");
332
- }
333
- assert_eq!(parse_pool_capacity_override(), None);
334
-
335
- unsafe {
336
- std::env::set_var("GTE_SESSION_POOL_CAP", "0");
337
- }
338
- assert_eq!(parse_pool_capacity_override(), None);
339
-
340
- unsafe {
341
- std::env::set_var("GTE_SESSION_POOL_CAP", "4");
342
- }
343
- assert_eq!(parse_pool_capacity_override(), Some(4));
344
-
345
- unsafe {
346
- std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
347
- }
348
- assert_eq!(parse_pool_capacity_override(), None);
349
-
350
- unsafe {
351
- std::env::remove_var("GTE_SESSION_POOL_CAP");
352
- }
353
- }
354
-
355
313
  #[test]
356
314
  fn extract_embeddings_raw_copies_only_final_matrix() {
357
315
  let output = array![[1.0f32, 2.0], [3.0, 4.0]];
358
- let extracted = extract_embeddings(
359
- output.view().into_dyn(),
360
- empty_attention_mask(),
361
- &test_config(ExtractorMode::Raw),
362
- )
363
- .unwrap();
316
+ let extracted =
317
+ extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Raw))
318
+ .unwrap();
364
319
 
365
320
  assert_eq!(extracted, output);
366
321
  }
367
322
 
368
323
  #[test]
369
324
  fn extract_embeddings_token_selects_without_copying_full_sequence() {
370
- let output = array![
371
- [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
372
- [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
373
- ];
325
+ let output = array![[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]];
374
326
  let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
375
- let extracted = extract_embeddings(
376
- output.view().into_dyn(),
377
- empty_attention_mask(),
378
- &test_config(ExtractorMode::Token(1)),
379
- )
380
- .unwrap();
327
+ let extracted =
328
+ extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Token(1)))
329
+ .unwrap();
381
330
 
382
331
  assert_eq!(extracted, expected);
383
332
  }
384
333
 
385
334
  #[test]
386
335
  fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
387
- let output = array![
388
- [[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
389
- [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
390
- ];
336
+ let output = array![[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]], [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]];
391
337
  let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
392
338
  let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
393
- let extracted = extract_embeddings(
394
- output.view().into_dyn(),
395
- attention_mask.view(),
396
- &test_config(ExtractorMode::MeanPool),
397
- )
398
- .unwrap();
339
+ let extracted =
340
+ extract_embeddings(output.view().into_dyn(), attention_mask.view(), &test_config(ExtractorMode::MeanPool))
341
+ .unwrap();
399
342
 
400
343
  assert_eq!(extracted, expected);
401
344
  }