gte 0.0.13-aarch64-linux → 0.0.15-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +93 -27
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +27 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +187 -244
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/gte.so +0 -0
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +3 -2
data/ext/gte/src/session.rs
CHANGED
|
@@ -4,224 +4,206 @@ use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
6
|
use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
|
|
7
|
-
use ort::execution_providers::{
|
|
8
|
-
CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
|
|
9
|
-
};
|
|
7
|
+
use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
|
|
10
8
|
use ort::session::{OutputSelector, RunOptions, Session};
|
|
9
|
+
use parking_lot::Mutex;
|
|
11
10
|
use std::path::{Path, PathBuf};
|
|
12
11
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
13
|
-
use std::sync::
|
|
14
|
-
|
|
15
|
-
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
16
|
-
let opt_level = match config.optimization_level {
|
|
17
|
-
0 => ort::session::builder::GraphOptimizationLevel::Disable,
|
|
18
|
-
1 => ort::session::builder::GraphOptimizationLevel::Level1,
|
|
19
|
-
2 => ort::session::builder::GraphOptimizationLevel::Level2,
|
|
20
|
-
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
21
|
-
};
|
|
22
|
-
|
|
23
|
-
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
24
|
-
GteError::Ort(e.to_string())
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
let mut builder = Session::builder()
|
|
28
|
-
.map_err(ort_err)?
|
|
29
|
-
.with_optimization_level(opt_level)
|
|
30
|
-
.map_err(ort_err)?;
|
|
31
|
-
|
|
32
|
-
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
33
|
-
if !providers.is_empty() {
|
|
34
|
-
builder = builder
|
|
35
|
-
.with_execution_providers(providers)
|
|
36
|
-
.map_err(ort_err)?;
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
builder.commit_from_file(model_path).map_err(ort_err)
|
|
40
|
-
}
|
|
12
|
+
use std::sync::Arc;
|
|
41
13
|
|
|
42
14
|
// ---------------------------------------------------------------------------
|
|
43
|
-
//
|
|
15
|
+
// Lazy session pool — starts with 1 session, grows on contention, capped.
|
|
16
|
+
//
|
|
17
|
+
// Pool max is resolved in order:
|
|
18
|
+
// 1. GTE_SESSION_POOL_SIZE env var (explicit override)
|
|
19
|
+
// 2. Auto: 2 (conservative: 2× pure Ruby memory at peak, no OOM risk)
|
|
20
|
+
//
|
|
21
|
+
// At idle the pool holds 1 session (same memory as pure Ruby's single
|
|
22
|
+
// OnnxRuntime::Model). When all existing sessions are busy and the cap
|
|
23
|
+
// hasn't been reached, a new session is created on-demand.
|
|
44
24
|
// ---------------------------------------------------------------------------
|
|
45
25
|
|
|
46
|
-
fn
|
|
47
|
-
let
|
|
48
|
-
.
|
|
49
|
-
|
|
50
|
-
|
|
26
|
+
fn resolve_pool_cap() -> usize {
|
|
27
|
+
if let Some(n) =
|
|
28
|
+
std::env::var("GTE_SESSION_POOL_SIZE").ok().and_then(|v| v.trim().parse::<usize>().ok()).filter(|&n| n > 0)
|
|
29
|
+
{
|
|
30
|
+
return n;
|
|
31
|
+
}
|
|
32
|
+
2
|
|
51
33
|
}
|
|
52
34
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
35
|
+
pub struct SessionPool {
|
|
36
|
+
inner: Mutex<PoolInner>,
|
|
37
|
+
next_idx: AtomicUsize,
|
|
38
|
+
cap: usize,
|
|
57
39
|
}
|
|
58
40
|
|
|
59
|
-
|
|
60
|
-
sessions: Mutex<
|
|
61
|
-
available: Condvar,
|
|
62
|
-
created: AtomicUsize,
|
|
63
|
-
capacity: usize,
|
|
41
|
+
struct PoolInner {
|
|
42
|
+
sessions: Vec<Arc<Mutex<Session>>>,
|
|
64
43
|
model_path: PathBuf,
|
|
65
44
|
build_config: ModelConfig,
|
|
66
45
|
}
|
|
67
46
|
|
|
68
47
|
impl SessionPool {
|
|
69
|
-
pub fn new(initial: Session, model_path:
|
|
70
|
-
let
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
pub fn acquire(&self) -> Result<PooledSession<'_>> {
|
|
82
|
-
if let Some(session) = self.take_available_session() {
|
|
83
|
-
return Ok(PooledSession {
|
|
84
|
-
pool: self,
|
|
85
|
-
session: Some(session),
|
|
86
|
-
});
|
|
87
|
-
}
|
|
88
|
-
|
|
89
|
-
if let Some(session) = self.try_grow()? {
|
|
90
|
-
return Ok(PooledSession {
|
|
91
|
-
pool: self,
|
|
92
|
-
session: Some(session),
|
|
93
|
-
});
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
let session = self.wait_for_session();
|
|
97
|
-
Ok(PooledSession {
|
|
98
|
-
pool: self,
|
|
99
|
-
session: Some(session),
|
|
48
|
+
pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
|
|
49
|
+
let cap = resolve_pool_cap();
|
|
50
|
+
let sessions = vec![Arc::new(Mutex::new(initial))];
|
|
51
|
+
|
|
52
|
+
Ok(Self {
|
|
53
|
+
inner: Mutex::new(PoolInner {
|
|
54
|
+
sessions,
|
|
55
|
+
model_path: model_path.to_path_buf(),
|
|
56
|
+
build_config: build_config.clone(),
|
|
57
|
+
}),
|
|
58
|
+
next_idx: AtomicUsize::new(0),
|
|
59
|
+
cap,
|
|
100
60
|
})
|
|
101
61
|
}
|
|
102
62
|
|
|
103
|
-
fn
|
|
104
|
-
self.
|
|
105
|
-
self.available.notify_one();
|
|
63
|
+
pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
64
|
+
self.with_session(|session| run_session(session, tokenized, config))
|
|
106
65
|
}
|
|
107
66
|
|
|
108
|
-
fn
|
|
109
|
-
|
|
110
|
-
|
|
67
|
+
pub fn with_session<F, R>(&self, f: F) -> Result<R>
|
|
68
|
+
where
|
|
69
|
+
F: FnOnce(&mut Session) -> Result<R>,
|
|
70
|
+
{
|
|
71
|
+
const SPIN_LIMIT: u32 = 64;
|
|
111
72
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
.
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
73
|
+
loop {
|
|
74
|
+
// Snapshot the pool under the outer lock so the scan below
|
|
75
|
+
// doesn't contend on that lock at all.
|
|
76
|
+
let arcs: Vec<Arc<Mutex<Session>>> = {
|
|
77
|
+
let inner = self.inner.lock();
|
|
78
|
+
inner.sessions.clone()
|
|
79
|
+
};
|
|
80
|
+
let len = arcs.len();
|
|
81
|
+
let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
|
|
82
|
+
|
|
83
|
+
for offset in 0..len {
|
|
84
|
+
let idx = (start + offset) % len;
|
|
85
|
+
if let Some(mut guard) = arcs[idx].try_lock() {
|
|
86
|
+
return f(&mut guard);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
121
89
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
90
|
+
// All sessions busy — try to grow the pool
|
|
91
|
+
let grew = {
|
|
92
|
+
let mut inner = self.inner.lock();
|
|
93
|
+
if inner.sessions.len() < self.cap {
|
|
94
|
+
match build_session(&inner.model_path, &inner.build_config) {
|
|
95
|
+
Ok(session) => {
|
|
96
|
+
inner.sessions.push(Arc::new(Mutex::new(session)));
|
|
97
|
+
true
|
|
98
|
+
}
|
|
99
|
+
Err(e) => return Err(e),
|
|
100
|
+
}
|
|
101
|
+
} else {
|
|
102
|
+
false
|
|
103
|
+
}
|
|
104
|
+
};
|
|
105
|
+
|
|
106
|
+
if grew {
|
|
107
|
+
continue;
|
|
127
108
|
}
|
|
128
|
-
}
|
|
129
|
-
}
|
|
130
109
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
110
|
+
// At cap — spin briefly, then block on a session
|
|
111
|
+
let idx = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
|
|
112
|
+
let arc = Arc::clone(&arcs[idx]);
|
|
113
|
+
|
|
114
|
+
for _ in 0..SPIN_LIMIT {
|
|
115
|
+
if let Some(mut guard) = arc.try_lock() {
|
|
116
|
+
return f(&mut guard);
|
|
117
|
+
}
|
|
118
|
+
std::hint::spin_loop();
|
|
136
119
|
}
|
|
137
|
-
|
|
120
|
+
|
|
121
|
+
let mut guard = arc.lock();
|
|
122
|
+
return f(&mut guard);
|
|
138
123
|
}
|
|
139
124
|
}
|
|
140
125
|
}
|
|
141
126
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
}
|
|
127
|
+
// ---------------------------------------------------------------------------
|
|
128
|
+
// Session construction
|
|
129
|
+
// ---------------------------------------------------------------------------
|
|
146
130
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
self.session.as_ref().unwrap()
|
|
131
|
+
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
132
|
+
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
133
|
+
GteError::Ort(e.to_string())
|
|
151
134
|
}
|
|
152
|
-
}
|
|
153
135
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
136
|
+
let opt_level = match config.optimization_level {
|
|
137
|
+
0 => ort::session::builder::GraphOptimizationLevel::Disable,
|
|
138
|
+
1 => ort::session::builder::GraphOptimizationLevel::Level1,
|
|
139
|
+
2 => ort::session::builder::GraphOptimizationLevel::Level2,
|
|
140
|
+
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
141
|
+
};
|
|
159
142
|
|
|
160
|
-
|
|
161
|
-
fn drop(&mut self) {
|
|
162
|
-
if let Some(s) = self.session.take() {
|
|
163
|
-
self.pool.release(s);
|
|
164
|
-
}
|
|
165
|
-
}
|
|
166
|
-
}
|
|
143
|
+
let mut builder = Session::builder().map_err(ort_err)?.with_optimization_level(opt_level).map_err(ort_err)?;
|
|
167
144
|
|
|
168
|
-
|
|
145
|
+
let intra_threads = std::env::var("GTE_INTRA_OP_NUM_THREADS")
|
|
146
|
+
.ok()
|
|
147
|
+
.and_then(|v| v.trim().parse::<usize>().ok())
|
|
148
|
+
.unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get().min(4)).unwrap_or(1));
|
|
149
|
+
builder = builder.with_intra_threads(intra_threads).map_err(ort_err)?;
|
|
169
150
|
|
|
170
|
-
|
|
171
|
-
|
|
151
|
+
let inter_threads =
|
|
152
|
+
std::env::var("GTE_INTER_OP_NUM_THREADS").ok().and_then(|v| v.trim().parse::<usize>().ok()).unwrap_or(1);
|
|
153
|
+
builder = builder.with_inter_threads(inter_threads).map_err(ort_err)?;
|
|
172
154
|
|
|
173
|
-
let
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
"coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
180
|
-
_ => {}
|
|
181
|
-
}
|
|
155
|
+
let providers = match config.execution_providers.as_deref() {
|
|
156
|
+
Some(override_val) => preferred_execution_providers(Some(override_val)),
|
|
157
|
+
None => auto_detect_providers(),
|
|
158
|
+
};
|
|
159
|
+
if !providers.is_empty() {
|
|
160
|
+
builder = builder.with_execution_providers(providers).map_err(ort_err)?;
|
|
182
161
|
}
|
|
183
|
-
providers
|
|
184
|
-
}
|
|
185
162
|
|
|
186
|
-
|
|
187
|
-
let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
|
|
188
|
-
resolve_provider_order_with_env(order_override, env_order.as_deref())
|
|
163
|
+
builder.commit_from_file(model_path).map_err(ort_err)
|
|
189
164
|
}
|
|
190
165
|
|
|
191
|
-
fn
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
)
|
|
195
|
-
|
|
196
|
-
.or(env_order)
|
|
197
|
-
.unwrap_or("cpu")
|
|
198
|
-
.to_ascii_lowercase()
|
|
166
|
+
fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
|
|
167
|
+
let mut providers = Vec::new();
|
|
168
|
+
#[cfg(target_arch = "aarch64")]
|
|
169
|
+
providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
|
|
170
|
+
providers
|
|
199
171
|
}
|
|
200
172
|
|
|
201
|
-
fn
|
|
202
|
-
let
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
173
|
+
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
174
|
+
let order = match order_override {
|
|
175
|
+
Some(s) => s.to_ascii_lowercase(),
|
|
176
|
+
None => return auto_detect_providers(),
|
|
177
|
+
};
|
|
178
|
+
|
|
179
|
+
if order.is_empty() || order == "cpu" || order == "none" {
|
|
180
|
+
return Vec::new();
|
|
209
181
|
}
|
|
182
|
+
|
|
183
|
+
let providers: Vec<_> = order
|
|
184
|
+
.split(',')
|
|
185
|
+
.map(str::trim)
|
|
186
|
+
.filter(|p| !p.is_empty())
|
|
187
|
+
.filter_map(|provider| match provider {
|
|
188
|
+
"xnnpack" => Some(XNNPACKExecutionProvider::default().build().fail_silently()),
|
|
189
|
+
"coreml" => Some(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
190
|
+
_ => None,
|
|
191
|
+
})
|
|
192
|
+
.collect();
|
|
210
193
|
providers
|
|
211
194
|
}
|
|
212
195
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
) -> Result<Array2<f32>> {
|
|
196
|
+
// ---------------------------------------------------------------------------
|
|
197
|
+
// Run a single inference
|
|
198
|
+
// ---------------------------------------------------------------------------
|
|
199
|
+
|
|
200
|
+
pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
218
201
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
219
202
|
let run_opts = RunOptions::new()
|
|
220
203
|
.map_err(|e| GteError::Ort(e.to_string()))?
|
|
221
204
|
.with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
|
|
222
|
-
let outputs =
|
|
223
|
-
.run_with_options(input_tensors.inputs, &run_opts)
|
|
224
|
-
.map_err(|e| GteError::Ort(e.to_string()))?;
|
|
205
|
+
let outputs =
|
|
206
|
+
session.run_with_options(input_tensors.inputs, &run_opts).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
225
207
|
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
226
208
|
|
|
227
209
|
extract_embeddings(array, input_tensors.attention_mask, config)
|
|
@@ -237,26 +219,21 @@ fn extract_embeddings(
|
|
|
237
219
|
let shape = array.shape();
|
|
238
220
|
if shape.len() != 3 || idx >= shape[1] {
|
|
239
221
|
return Err(GteError::Inference(format!(
|
|
240
|
-
"token extraction index {} out of bounds for output shape {:?}"
|
|
241
|
-
idx, shape
|
|
222
|
+
"token extraction index {idx} out of bounds for output shape {shape:?}"
|
|
242
223
|
)));
|
|
243
224
|
}
|
|
244
225
|
Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
|
|
245
226
|
}
|
|
246
227
|
ExtractorMode::MeanPool => {
|
|
247
228
|
let ndim = array.ndim();
|
|
248
|
-
let hidden_states = array
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
ndim
|
|
252
|
-
))
|
|
253
|
-
})?;
|
|
229
|
+
let hidden_states = array
|
|
230
|
+
.into_dimensionality::<ndarray::Ix3>()
|
|
231
|
+
.map_err(|_| GteError::Inference(format!("mean pooling requires rank-3 output, got rank {ndim}")))?;
|
|
254
232
|
mean_pool(hidden_states, attention_mask)
|
|
255
233
|
}
|
|
256
|
-
ExtractorMode::Raw =>
|
|
257
|
-
.into_dimensionality::<Ix2>()
|
|
258
|
-
|
|
259
|
-
.map_err(|e| GteError::Shape(e.to_string())),
|
|
234
|
+
ExtractorMode::Raw => {
|
|
235
|
+
array.into_dimensionality::<Ix2>().map(|view| view.to_owned()).map_err(|e| GteError::Shape(e.to_string()))
|
|
236
|
+
}
|
|
260
237
|
}
|
|
261
238
|
}
|
|
262
239
|
|
|
@@ -265,10 +242,22 @@ mod tests {
|
|
|
265
242
|
use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
|
|
266
243
|
use ndarray::{array, ArrayView2};
|
|
267
244
|
|
|
268
|
-
use super::
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
245
|
+
use super::extract_embeddings;
|
|
246
|
+
|
|
247
|
+
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
248
|
+
order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
252
|
+
let mut providers = Vec::new();
|
|
253
|
+
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
254
|
+
match provider {
|
|
255
|
+
"xnnpack" | "coreml" => providers.push(provider),
|
|
256
|
+
_ => {}
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
providers
|
|
260
|
+
}
|
|
272
261
|
|
|
273
262
|
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
274
263
|
ModelConfig {
|
|
@@ -280,6 +269,8 @@ mod tests {
|
|
|
280
269
|
with_attention_mask: true,
|
|
281
270
|
optimization_level: 3,
|
|
282
271
|
execution_providers: None,
|
|
272
|
+
lowercase_input: false,
|
|
273
|
+
max_input_chars: None,
|
|
283
274
|
}
|
|
284
275
|
}
|
|
285
276
|
|
|
@@ -309,93 +300,45 @@ mod tests {
|
|
|
309
300
|
|
|
310
301
|
#[test]
|
|
311
302
|
fn resolve_provider_order_prefers_override() {
|
|
312
|
-
assert_eq!(
|
|
313
|
-
resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
|
|
314
|
-
"xnnpack"
|
|
315
|
-
);
|
|
303
|
+
assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
|
|
316
304
|
assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
|
|
317
305
|
}
|
|
318
306
|
|
|
319
307
|
#[test]
|
|
320
308
|
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
321
|
-
assert_eq!(
|
|
322
|
-
resolve_provider_order_with_env(None, Some("coreml")),
|
|
323
|
-
"coreml"
|
|
324
|
-
);
|
|
309
|
+
assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
|
|
325
310
|
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
326
311
|
}
|
|
327
312
|
|
|
328
|
-
#[test]
|
|
329
|
-
fn parse_pool_capacity_override_uses_positive_integer_only() {
|
|
330
|
-
unsafe {
|
|
331
|
-
std::env::remove_var("GTE_SESSION_POOL_CAP");
|
|
332
|
-
}
|
|
333
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
334
|
-
|
|
335
|
-
unsafe {
|
|
336
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "0");
|
|
337
|
-
}
|
|
338
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
339
|
-
|
|
340
|
-
unsafe {
|
|
341
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "4");
|
|
342
|
-
}
|
|
343
|
-
assert_eq!(parse_pool_capacity_override(), Some(4));
|
|
344
|
-
|
|
345
|
-
unsafe {
|
|
346
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
|
|
347
|
-
}
|
|
348
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
349
|
-
|
|
350
|
-
unsafe {
|
|
351
|
-
std::env::remove_var("GTE_SESSION_POOL_CAP");
|
|
352
|
-
}
|
|
353
|
-
}
|
|
354
|
-
|
|
355
313
|
#[test]
|
|
356
314
|
fn extract_embeddings_raw_copies_only_final_matrix() {
|
|
357
315
|
let output = array![[1.0f32, 2.0], [3.0, 4.0]];
|
|
358
|
-
let extracted =
|
|
359
|
-
output.view().into_dyn(),
|
|
360
|
-
|
|
361
|
-
&test_config(ExtractorMode::Raw),
|
|
362
|
-
)
|
|
363
|
-
.unwrap();
|
|
316
|
+
let extracted =
|
|
317
|
+
extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Raw))
|
|
318
|
+
.unwrap();
|
|
364
319
|
|
|
365
320
|
assert_eq!(extracted, output);
|
|
366
321
|
}
|
|
367
322
|
|
|
368
323
|
#[test]
|
|
369
324
|
fn extract_embeddings_token_selects_without_copying_full_sequence() {
|
|
370
|
-
let output = array![
|
|
371
|
-
[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
|
372
|
-
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
|
|
373
|
-
];
|
|
325
|
+
let output = array![[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]];
|
|
374
326
|
let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
|
|
375
|
-
let extracted =
|
|
376
|
-
output.view().into_dyn(),
|
|
377
|
-
|
|
378
|
-
&test_config(ExtractorMode::Token(1)),
|
|
379
|
-
)
|
|
380
|
-
.unwrap();
|
|
327
|
+
let extracted =
|
|
328
|
+
extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Token(1)))
|
|
329
|
+
.unwrap();
|
|
381
330
|
|
|
382
331
|
assert_eq!(extracted, expected);
|
|
383
332
|
}
|
|
384
333
|
|
|
385
334
|
#[test]
|
|
386
335
|
fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
|
|
387
|
-
let output = array![
|
|
388
|
-
[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
|
|
389
|
-
[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
|
|
390
|
-
];
|
|
336
|
+
let output = array![[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]], [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]];
|
|
391
337
|
let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
|
|
392
338
|
let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
|
|
393
|
-
let extracted =
|
|
394
|
-
output.view().into_dyn(),
|
|
395
|
-
|
|
396
|
-
&test_config(ExtractorMode::MeanPool),
|
|
397
|
-
)
|
|
398
|
-
.unwrap();
|
|
339
|
+
let extracted =
|
|
340
|
+
extract_embeddings(output.view().into_dyn(), attention_mask.view(), &test_config(ExtractorMode::MeanPool))
|
|
341
|
+
.unwrap();
|
|
399
342
|
|
|
400
343
|
assert_eq!(extracted, expected);
|
|
401
344
|
}
|