gte 0.0.15 → 0.0.16
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +0 -1
- data/README.md +112 -82
- data/Rakefile +0 -9
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +29 -65
- data/ext/gte/src/lib.rs +1 -0
- data/ext/gte/src/model_config.rs +0 -4
- data/ext/gte/src/pipeline.rs +8 -9
- data/ext/gte/src/postprocess.rs +8 -6
- data/ext/gte/src/reranker.rs +7 -10
- data/ext/gte/src/ruby_embedder.rs +10 -33
- data/ext/gte/src/session.rs +50 -156
- data/ext/gte/src/tokenizer.rs +45 -38
- data/ext/gte/tests/embedder_unit_test.rs +1 -1
- data/ext/gte/tests/padding_regression_test.rs +7 -25
- data/ext/gte/tests/tokenizer_unit_test.rs +7 -7
- data/lib/gte/config.rb +1 -2
- data/lib/gte/embedder.rb +2 -14
- data/lib/gte/model.rb +0 -7
- data/lib/gte/reranker.rb +14 -33
- data/lib/gte.rb +4 -25
- metadata +1 -1
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -6,7 +6,7 @@ use crate::model_profile::{
|
|
|
6
6
|
};
|
|
7
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
8
8
|
use crate::postprocess::sigmoid_scores;
|
|
9
|
-
use crate::session::{build_session, SessionPool};
|
|
9
|
+
use crate::session::{build_session, resolve_pool_size, SessionPool};
|
|
10
10
|
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
11
11
|
use std::path::{Path, PathBuf};
|
|
12
12
|
|
|
@@ -54,8 +54,6 @@ impl Reranker {
|
|
|
54
54
|
with_attention_mask: true,
|
|
55
55
|
optimization_level,
|
|
56
56
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
57
|
-
lowercase_input: false,
|
|
58
|
-
max_input_chars: None,
|
|
59
57
|
};
|
|
60
58
|
let session = build_session(&model_path, &probe_config)?;
|
|
61
59
|
|
|
@@ -83,10 +81,8 @@ impl Reranker {
|
|
|
83
81
|
with_attention_mask: config.with_attention_mask,
|
|
84
82
|
optimization_level,
|
|
85
83
|
execution_providers: None,
|
|
86
|
-
lowercase_input: false,
|
|
87
|
-
max_input_chars: None,
|
|
88
84
|
};
|
|
89
|
-
let pool = SessionPool::new(
|
|
85
|
+
let pool = SessionPool::new(&model_path, &model_config, resolve_pool_size())?;
|
|
90
86
|
Ok(Self { tokenizer, pool, config })
|
|
91
87
|
}
|
|
92
88
|
|
|
@@ -102,13 +98,12 @@ impl Reranker {
|
|
|
102
98
|
|
|
103
99
|
fn score_tokenized(&self, tokenized: &crate::tokenizer::Tokenized, apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
104
100
|
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
105
|
-
let output_name = self.config.output_tensor.clone();
|
|
106
101
|
let inputs = input_tensors.inputs;
|
|
107
102
|
|
|
108
103
|
self.pool.with_session(|session| {
|
|
109
104
|
let outputs = session.run(inputs).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
110
105
|
|
|
111
|
-
let array = extract_output_tensor(&outputs,
|
|
106
|
+
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
112
107
|
|
|
113
108
|
let mut scores = match array.ndim() {
|
|
114
109
|
1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
|
|
@@ -116,14 +111,16 @@ impl Reranker {
|
|
|
116
111
|
let shape = array.shape();
|
|
117
112
|
if shape[1] == 0 {
|
|
118
113
|
return Err(GteError::Inference(format!(
|
|
119
|
-
"reranker output '{
|
|
114
|
+
"reranker output '{}' has invalid shape {shape:?}",
|
|
115
|
+
self.config.output_tensor
|
|
120
116
|
)));
|
|
121
117
|
}
|
|
122
118
|
array.slice(ndarray::s![.., 0]).to_vec()
|
|
123
119
|
}
|
|
124
120
|
n => {
|
|
125
121
|
return Err(GteError::Inference(format!(
|
|
126
|
-
"reranker output '{
|
|
122
|
+
"reranker output '{}' rank {n} is unsupported; expected rank 1 or 2",
|
|
123
|
+
self.config.output_tensor
|
|
127
124
|
)))
|
|
128
125
|
}
|
|
129
126
|
};
|
|
data/ext/gte/src/ruby_embedder.rs
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
#![allow(unused_results)]
|
|
4
4
|
#![allow(unused_qualifications)]
|
|
5
5
|
|
|
6
|
-
use crate::embedder::
|
|
6
|
+
use crate::embedder::Embedder;
|
|
7
7
|
use crate::error::GteError;
|
|
8
8
|
use crate::model_config::ModelLoadOverrides;
|
|
9
9
|
use crate::reranker::Reranker;
|
|
@@ -15,7 +15,6 @@ use std::sync::Arc;
|
|
|
15
15
|
#[wrap(class = "GTE::Embedder", free_immediately, size)]
|
|
16
16
|
pub struct RbEmbedder {
|
|
17
17
|
inner: Arc<Embedder>,
|
|
18
|
-
normalize: bool,
|
|
19
18
|
}
|
|
20
19
|
|
|
21
20
|
#[wrap(class = "GTE::Reranker", free_immediately, size)]
|
|
@@ -38,7 +37,6 @@ pub struct RbTensor {
|
|
|
38
37
|
struct InferArgs {
|
|
39
38
|
embedder: *const Embedder,
|
|
40
39
|
texts: *const Vec<String>,
|
|
41
|
-
normalize: bool,
|
|
42
40
|
result: Option<crate::error::Result<ndarray::Array2<f32>>>,
|
|
43
41
|
}
|
|
44
42
|
|
|
@@ -66,15 +64,7 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
|
|
|
66
64
|
|
|
67
65
|
unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
68
66
|
let args = &mut *(ptr as *mut InferArgs);
|
|
69
|
-
let run_result = catch_unwind(AssertUnwindSafe(||
|
|
70
|
-
// Full embedding path (tokenization + inference) runs without the GVL.
|
|
71
|
-
let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
|
|
72
|
-
if args.normalize {
|
|
73
|
-
Ok(normalize_l2(embeddings))
|
|
74
|
-
} else {
|
|
75
|
-
Ok(embeddings)
|
|
76
|
-
}
|
|
77
|
-
}));
|
|
67
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| (*args.embedder).embed(&*args.texts)));
|
|
78
68
|
args.result = Some(match run_result {
|
|
79
69
|
Ok(result) => result,
|
|
80
70
|
Err(payload) => {
|
|
@@ -97,14 +87,9 @@ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
|
97
87
|
std::ptr::null_mut()
|
|
98
88
|
}
|
|
99
89
|
|
|
100
|
-
fn infer_without_gvl(
|
|
101
|
-
embedder: &Arc<Embedder>,
|
|
102
|
-
normalize: bool,
|
|
103
|
-
texts: Vec<String>,
|
|
104
|
-
) -> Result<ndarray::Array2<f32>, Error> {
|
|
90
|
+
fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
|
|
105
91
|
let embeddings = unsafe {
|
|
106
|
-
let mut args =
|
|
107
|
-
InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, normalize, result: None };
|
|
92
|
+
let mut args = InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, result: None };
|
|
108
93
|
rb_sys::rb_thread_call_without_gvl(
|
|
109
94
|
Some(run_embed_without_gvl),
|
|
110
95
|
&mut args as *mut InferArgs as *mut c_void,
|
|
@@ -167,13 +152,10 @@ impl RbEmbedder {
|
|
|
167
152
|
dir_path: String,
|
|
168
153
|
optimization_level: u8,
|
|
169
154
|
model_name: String,
|
|
170
|
-
normalize: bool,
|
|
171
155
|
output_tensor: String,
|
|
172
156
|
max_length: usize,
|
|
173
157
|
padding: String,
|
|
174
158
|
execution_providers: String,
|
|
175
|
-
lowercase_input: bool,
|
|
176
|
-
max_input_chars: usize,
|
|
177
159
|
) -> Result<Self, Error> {
|
|
178
160
|
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
179
161
|
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
@@ -181,29 +163,26 @@ impl RbEmbedder {
|
|
|
181
163
|
let execution_providers_override =
|
|
182
164
|
if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
183
165
|
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
184
|
-
let max_input_chars_override = if max_input_chars == 0 { None } else { Some(max_input_chars) };
|
|
185
166
|
let overrides = ModelLoadOverrides {
|
|
186
167
|
model_name: name,
|
|
187
168
|
output_tensor: output_override,
|
|
188
169
|
max_length: max_length_override,
|
|
189
170
|
padding: padding_override,
|
|
190
171
|
execution_providers: execution_providers_override,
|
|
191
|
-
|
|
192
|
-
max_input_chars: max_input_chars_override,
|
|
172
|
+
..ModelLoadOverrides::default()
|
|
193
173
|
};
|
|
194
174
|
let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
|
|
195
|
-
|
|
196
|
-
Ok(RbEmbedder { inner: Arc::new(embedder), normalize: normalize && !skip_normalize })
|
|
175
|
+
Ok(RbEmbedder { inner: Arc::new(embedder) })
|
|
197
176
|
}
|
|
198
177
|
|
|
199
178
|
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
|
|
200
179
|
let texts: Vec<String> = texts.to_vec()?;
|
|
201
|
-
let embeddings = infer_without_gvl(&rb_self.inner,
|
|
180
|
+
let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
|
|
202
181
|
tensor_from_array(embeddings)
|
|
203
182
|
}
|
|
204
183
|
|
|
205
184
|
pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
|
|
206
|
-
let embeddings = infer_without_gvl(&rb_self.inner,
|
|
185
|
+
let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
|
|
207
186
|
tensor_from_array(embeddings)
|
|
208
187
|
}
|
|
209
188
|
}
|
|
@@ -219,8 +198,6 @@ impl RbReranker {
|
|
|
219
198
|
max_length: usize,
|
|
220
199
|
padding: String,
|
|
221
200
|
execution_providers: String,
|
|
222
|
-
_lowercase_input: bool,
|
|
223
|
-
_max_input_chars: usize,
|
|
224
201
|
) -> Result<Self, Error> {
|
|
225
202
|
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
226
203
|
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
@@ -335,12 +312,12 @@ impl RbTensor {
|
|
|
335
312
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
336
313
|
let module = ruby.define_module("GTE")?;
|
|
337
314
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
338
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
315
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
|
|
339
316
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
340
317
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
341
318
|
|
|
342
319
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
343
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new,
|
|
320
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
|
|
344
321
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
345
322
|
|
|
346
323
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|
data/ext/gte/src/session.rs
CHANGED
|
@@ -3,132 +3,58 @@ use crate::model_config::{ExtractorMode, ModelConfig};
|
|
|
3
3
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
|
-
use ndarray::{Array2,
|
|
6
|
+
use ndarray::{Array2, ArrayViewD, Ix2};
|
|
7
7
|
use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
|
|
8
8
|
use ort::session::{OutputSelector, RunOptions, Session};
|
|
9
9
|
use parking_lot::Mutex;
|
|
10
|
-
use std::path::
|
|
10
|
+
use std::path::Path;
|
|
11
11
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
12
|
-
use std::sync::Arc;
|
|
13
12
|
|
|
14
|
-
|
|
15
|
-
// Lazy session pool — starts with 1 session, grows on contention, capped.
|
|
16
|
-
//
|
|
17
|
-
// Pool max is resolved in order:
|
|
18
|
-
// 1. GTE_SESSION_POOL_SIZE env var (explicit override)
|
|
19
|
-
// 2. Auto: 2 (conservative: 2× pure Ruby memory at peak, no OOM risk)
|
|
20
|
-
//
|
|
21
|
-
// At idle the pool holds 1 session (same memory as pure Ruby's single
|
|
22
|
-
// OnnxRuntime::Model). When all existing sessions are busy and the cap
|
|
23
|
-
// hasn't been reached, a new session is created on-demand.
|
|
24
|
-
// ---------------------------------------------------------------------------
|
|
25
|
-
|
|
26
|
-
fn resolve_pool_cap() -> usize {
|
|
13
|
+
pub(crate) fn resolve_pool_size() -> usize {
|
|
27
14
|
if let Some(n) =
|
|
28
15
|
std::env::var("GTE_SESSION_POOL_SIZE").ok().and_then(|v| v.trim().parse::<usize>().ok()).filter(|&n| n > 0)
|
|
29
16
|
{
|
|
30
17
|
return n;
|
|
31
18
|
}
|
|
32
|
-
2
|
|
19
|
+
let cpus = std::thread::available_parallelism().map(std::num::NonZero::get).unwrap_or(2);
|
|
20
|
+
cpus.min(4).max(1)
|
|
33
21
|
}
|
|
34
22
|
|
|
35
23
|
pub struct SessionPool {
|
|
36
|
-
|
|
24
|
+
sessions: Vec<Mutex<Session>>,
|
|
37
25
|
next_idx: AtomicUsize,
|
|
38
|
-
cap: usize,
|
|
39
|
-
}
|
|
40
|
-
|
|
41
|
-
struct PoolInner {
|
|
42
|
-
sessions: Vec<Arc<Mutex<Session>>>,
|
|
43
|
-
model_path: PathBuf,
|
|
44
|
-
build_config: ModelConfig,
|
|
45
26
|
}
|
|
46
27
|
|
|
47
28
|
impl SessionPool {
|
|
48
|
-
pub fn new(
|
|
49
|
-
let
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
build_config: build_config.clone(),
|
|
57
|
-
}),
|
|
58
|
-
next_idx: AtomicUsize::new(0),
|
|
59
|
-
cap,
|
|
60
|
-
})
|
|
61
|
-
}
|
|
62
|
-
|
|
63
|
-
pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
64
|
-
self.with_session(|session| run_session(session, tokenized, config))
|
|
29
|
+
pub fn new(model_path: &Path, config: &ModelConfig, pool_size: usize) -> Result<Self> {
|
|
30
|
+
let sessions = (0..pool_size)
|
|
31
|
+
.map(|_| build_session(model_path, config))
|
|
32
|
+
.collect::<Result<Vec<_>>>()?
|
|
33
|
+
.into_iter()
|
|
34
|
+
.map(Mutex::new)
|
|
35
|
+
.collect();
|
|
36
|
+
Ok(Self { sessions, next_idx: AtomicUsize::new(0) })
|
|
65
37
|
}
|
|
66
38
|
|
|
67
39
|
pub fn with_session<F, R>(&self, f: F) -> Result<R>
|
|
68
40
|
where
|
|
69
41
|
F: FnOnce(&mut Session) -> Result<R>,
|
|
70
42
|
{
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
};
|
|
80
|
-
let len = arcs.len();
|
|
81
|
-
let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
|
|
82
|
-
|
|
83
|
-
for offset in 0..len {
|
|
84
|
-
let idx = (start + offset) % len;
|
|
85
|
-
if let Some(mut guard) = arcs[idx].try_lock() {
|
|
86
|
-
return f(&mut guard);
|
|
87
|
-
}
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
// All sessions busy — try to grow the pool
|
|
91
|
-
let grew = {
|
|
92
|
-
let mut inner = self.inner.lock();
|
|
93
|
-
if inner.sessions.len() < self.cap {
|
|
94
|
-
match build_session(&inner.model_path, &inner.build_config) {
|
|
95
|
-
Ok(session) => {
|
|
96
|
-
inner.sessions.push(Arc::new(Mutex::new(session)));
|
|
97
|
-
true
|
|
98
|
-
}
|
|
99
|
-
Err(e) => return Err(e),
|
|
100
|
-
}
|
|
101
|
-
} else {
|
|
102
|
-
false
|
|
103
|
-
}
|
|
104
|
-
};
|
|
105
|
-
|
|
106
|
-
if grew {
|
|
107
|
-
continue;
|
|
108
|
-
}
|
|
109
|
-
|
|
110
|
-
// At cap — spin briefly, then block on a session
|
|
111
|
-
let idx = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
|
|
112
|
-
let arc = Arc::clone(&arcs[idx]);
|
|
113
|
-
|
|
114
|
-
for _ in 0..SPIN_LIMIT {
|
|
115
|
-
if let Some(mut guard) = arc.try_lock() {
|
|
116
|
-
return f(&mut guard);
|
|
117
|
-
}
|
|
118
|
-
std::hint::spin_loop();
|
|
119
|
-
}
|
|
43
|
+
let idx = if self.sessions.len() == 1 {
|
|
44
|
+
0
|
|
45
|
+
} else {
|
|
46
|
+
self.next_idx.fetch_add(1, Ordering::Relaxed) % self.sessions.len()
|
|
47
|
+
};
|
|
48
|
+
let mut session = self.sessions[idx].lock();
|
|
49
|
+
f(&mut session)
|
|
50
|
+
}
|
|
120
51
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
}
|
|
52
|
+
pub fn len(&self) -> usize {
|
|
53
|
+
self.sessions.len()
|
|
124
54
|
}
|
|
125
55
|
}
|
|
126
56
|
|
|
127
|
-
|
|
128
|
-
// Session construction
|
|
129
|
-
// ---------------------------------------------------------------------------
|
|
130
|
-
|
|
131
|
-
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
57
|
+
pub(crate) fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
132
58
|
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
133
59
|
GteError::Ort(e.to_string())
|
|
134
60
|
}
|
|
@@ -164,10 +90,14 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
164
90
|
}
|
|
165
91
|
|
|
166
92
|
fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
|
|
167
|
-
let mut providers = Vec::new();
|
|
168
93
|
#[cfg(target_arch = "aarch64")]
|
|
169
|
-
|
|
170
|
-
|
|
94
|
+
{
|
|
95
|
+
vec![XNNPACKExecutionProvider::default().build().fail_silently()]
|
|
96
|
+
}
|
|
97
|
+
#[cfg(not(target_arch = "aarch64"))]
|
|
98
|
+
{
|
|
99
|
+
Vec::new()
|
|
100
|
+
}
|
|
171
101
|
}
|
|
172
102
|
|
|
173
103
|
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
@@ -193,11 +123,7 @@ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionP
|
|
|
193
123
|
providers
|
|
194
124
|
}
|
|
195
125
|
|
|
196
|
-
|
|
197
|
-
// Run a single inference
|
|
198
|
-
// ---------------------------------------------------------------------------
|
|
199
|
-
|
|
200
|
-
pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
126
|
+
pub(crate) fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
201
127
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
202
128
|
let run_opts = RunOptions::new()
|
|
203
129
|
.map_err(|e| GteError::Ort(e.to_string()))?
|
|
@@ -211,7 +137,7 @@ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelC
|
|
|
211
137
|
|
|
212
138
|
fn extract_embeddings(
|
|
213
139
|
array: ArrayViewD<'_, f32>,
|
|
214
|
-
attention_mask: ArrayView2<'_, i64>,
|
|
140
|
+
attention_mask: ndarray::ArrayView2<'_, i64>,
|
|
215
141
|
config: &ModelConfig,
|
|
216
142
|
) -> Result<Array2<f32>> {
|
|
217
143
|
match config.mode {
|
|
@@ -244,21 +170,6 @@ mod tests {
|
|
|
244
170
|
|
|
245
171
|
use super::extract_embeddings;
|
|
246
172
|
|
|
247
|
-
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
248
|
-
order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
|
|
249
|
-
}
|
|
250
|
-
|
|
251
|
-
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
252
|
-
let mut providers = Vec::new();
|
|
253
|
-
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
254
|
-
match provider {
|
|
255
|
-
"xnnpack" | "coreml" => providers.push(provider),
|
|
256
|
-
_ => {}
|
|
257
|
-
}
|
|
258
|
-
}
|
|
259
|
-
providers
|
|
260
|
-
}
|
|
261
|
-
|
|
262
173
|
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
263
174
|
ModelConfig {
|
|
264
175
|
max_length: 8,
|
|
@@ -269,8 +180,6 @@ mod tests {
|
|
|
269
180
|
with_attention_mask: true,
|
|
270
181
|
optimization_level: 3,
|
|
271
182
|
execution_providers: None,
|
|
272
|
-
lowercase_input: false,
|
|
273
|
-
max_input_chars: None,
|
|
274
183
|
}
|
|
275
184
|
}
|
|
276
185
|
|
|
@@ -279,37 +188,6 @@ mod tests {
|
|
|
279
188
|
ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
|
|
280
189
|
}
|
|
281
190
|
|
|
282
|
-
#[test]
|
|
283
|
-
fn parse_provider_registrations_keeps_supported_order() {
|
|
284
|
-
let parsed = parse_provider_registrations("xnnpack,coreml");
|
|
285
|
-
assert_eq!(parsed, vec!["xnnpack", "coreml"]);
|
|
286
|
-
}
|
|
287
|
-
|
|
288
|
-
#[test]
|
|
289
|
-
fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
|
|
290
|
-
assert!(parse_provider_registrations("cpu").is_empty());
|
|
291
|
-
assert!(parse_provider_registrations("none").is_empty());
|
|
292
|
-
assert!(parse_provider_registrations("none,cpu").is_empty());
|
|
293
|
-
}
|
|
294
|
-
|
|
295
|
-
#[test]
|
|
296
|
-
fn parse_provider_registrations_ignores_unknowns_and_empties() {
|
|
297
|
-
let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
|
|
298
|
-
assert_eq!(parsed, vec!["xnnpack", "coreml"]);
|
|
299
|
-
}
|
|
300
|
-
|
|
301
|
-
#[test]
|
|
302
|
-
fn resolve_provider_order_prefers_override() {
|
|
303
|
-
assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
|
|
304
|
-
assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
|
|
305
|
-
}
|
|
306
|
-
|
|
307
|
-
#[test]
|
|
308
|
-
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
309
|
-
assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
|
|
310
|
-
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
311
|
-
}
|
|
312
|
-
|
|
313
191
|
#[test]
|
|
314
192
|
fn extract_embeddings_raw_copies_only_final_matrix() {
|
|
315
193
|
let output = array![[1.0f32, 2.0], [3.0, 4.0]];
|
|
@@ -342,4 +220,20 @@ mod tests {
|
|
|
342
220
|
|
|
343
221
|
assert_eq!(extracted, expected);
|
|
344
222
|
}
|
|
223
|
+
|
|
224
|
+
#[test]
|
|
225
|
+
fn resolve_pool_size_uses_env_var() {
|
|
226
|
+
std::env::set_var("GTE_SESSION_POOL_SIZE", "16");
|
|
227
|
+
let size = super::resolve_pool_size();
|
|
228
|
+
assert_eq!(size, 16);
|
|
229
|
+
std::env::remove_var("GTE_SESSION_POOL_SIZE");
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
#[test]
|
|
233
|
+
fn resolve_pool_size_defaults_to_cpu_count_capped_at_4() {
|
|
234
|
+
// Without GTE_SESSION_POOL_SIZE, the default is min(available_parallelism, 4).max(1).
|
|
235
|
+
// On any machine with >= 1 CPU, this should return between 1 and 4.
|
|
236
|
+
let size = super::resolve_pool_size();
|
|
237
|
+
assert!((1..=4).contains(&size), "expected 1-4, got {size}");
|
|
238
|
+
}
|
|
345
239
|
}
|
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::PaddingMode;
|
|
3
|
+
use ndarray::Array2;
|
|
3
4
|
use std::path::Path;
|
|
4
5
|
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
|
|
5
6
|
|
|
6
7
|
pub struct Tokenized {
|
|
7
|
-
pub
|
|
8
|
-
pub
|
|
9
|
-
pub
|
|
10
|
-
pub attn_masks: Vec<i64>,
|
|
11
|
-
pub type_ids: Option<Vec<i64>>,
|
|
8
|
+
pub input_ids: Array2<i64>,
|
|
9
|
+
pub attn_masks: Array2<i64>,
|
|
10
|
+
pub type_ids: Option<Array2<i64>>,
|
|
12
11
|
}
|
|
13
12
|
|
|
14
13
|
pub struct Tokenizer {
|
|
@@ -24,7 +23,6 @@ impl Tokenizer {
|
|
|
24
23
|
padding_mode: PaddingMode,
|
|
25
24
|
fixed_padding_length: Option<usize>,
|
|
26
25
|
) -> Result<Self> {
|
|
27
|
-
#[allow(unused_results)]
|
|
28
26
|
{
|
|
29
27
|
let mut tokenizer =
|
|
30
28
|
tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
@@ -34,41 +32,59 @@ impl Tokenizer {
|
|
|
34
32
|
strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
|
|
35
33
|
..Default::default()
|
|
36
34
|
};
|
|
37
|
-
tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
38
|
-
tokenizer.with_padding(Some(padding));
|
|
35
|
+
let _ = tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
36
|
+
let _ = tokenizer.with_padding(Some(padding));
|
|
39
37
|
|
|
40
38
|
Ok(Self { tokenizer, with_type_ids })
|
|
41
39
|
}
|
|
42
40
|
}
|
|
43
41
|
|
|
44
42
|
pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
|
|
45
|
-
if texts.
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
43
|
+
if texts.is_empty() {
|
|
44
|
+
return Ok(Tokenized {
|
|
45
|
+
input_ids: Array2::zeros((0, 0)),
|
|
46
|
+
attn_masks: Array2::zeros((0, 0)),
|
|
47
|
+
type_ids: None,
|
|
48
|
+
});
|
|
49
49
|
}
|
|
50
50
|
|
|
51
51
|
let encode_inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
|
|
52
52
|
let encodings =
|
|
53
53
|
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
54
54
|
|
|
55
|
-
|
|
55
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
56
56
|
}
|
|
57
57
|
|
|
58
58
|
pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
|
|
59
|
+
if pairs.is_empty() {
|
|
60
|
+
return Ok(Tokenized {
|
|
61
|
+
input_ids: Array2::zeros((0, 0)),
|
|
62
|
+
attn_masks: Array2::zeros((0, 0)),
|
|
63
|
+
type_ids: None,
|
|
64
|
+
});
|
|
65
|
+
}
|
|
66
|
+
|
|
59
67
|
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
60
68
|
pairs.iter().map(|(left, right)| (left.as_str(), right.as_str()).into()).collect();
|
|
61
69
|
let encodings =
|
|
62
70
|
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
63
|
-
|
|
71
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
64
72
|
}
|
|
65
73
|
|
|
66
74
|
pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
|
|
75
|
+
if candidates.is_empty() {
|
|
76
|
+
return Ok(Tokenized {
|
|
77
|
+
input_ids: Array2::zeros((0, 0)),
|
|
78
|
+
attn_masks: Array2::zeros((0, 0)),
|
|
79
|
+
type_ids: None,
|
|
80
|
+
});
|
|
81
|
+
}
|
|
82
|
+
|
|
67
83
|
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
68
84
|
candidates.iter().map(|candidate| (query, candidate.as_str()).into()).collect();
|
|
69
85
|
let encodings =
|
|
70
86
|
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
71
|
-
|
|
87
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
72
88
|
}
|
|
73
89
|
}
|
|
74
90
|
|
|
@@ -102,36 +118,30 @@ fn resolve_padding_strategy(
|
|
|
102
118
|
}
|
|
103
119
|
}
|
|
104
120
|
|
|
105
|
-
fn
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&v| i64::from(v)).collect();
|
|
109
|
-
let attn_masks: Vec<i64> = encoding.get_attention_mask().iter().map(|&v| i64::from(v)).collect();
|
|
110
|
-
let type_ids: Option<Vec<i64>> =
|
|
111
|
-
with_type_ids.then(|| encoding.get_type_ids().iter().map(|&v| i64::from(v)).collect());
|
|
112
|
-
|
|
113
|
-
Tokenized { rows: 1, cols, input_ids, attn_masks, type_ids }
|
|
121
|
+
fn to_i64(array: &[u32]) -> Vec<i64> {
|
|
122
|
+
array.iter().map(|&v| v as i64).collect()
|
|
114
123
|
}
|
|
115
124
|
|
|
116
|
-
fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Tokenized {
|
|
125
|
+
fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Result<Tokenized> {
|
|
117
126
|
let rows = encodings.len();
|
|
118
127
|
let cols = encodings.first().map_or(0, tokenizers::Encoding::len);
|
|
119
|
-
|
|
128
|
+
if rows == 0 || cols == 0 {
|
|
129
|
+
return Ok(Tokenized { input_ids: Array2::zeros((0, 0)), attn_masks: Array2::zeros((0, 0)), type_ids: None });
|
|
130
|
+
}
|
|
120
131
|
|
|
121
|
-
let mut input_ids =
|
|
122
|
-
let mut attn_masks =
|
|
123
|
-
let mut type_ids = with_type_ids.then(||
|
|
132
|
+
let mut input_ids = Array2::zeros((0, cols));
|
|
133
|
+
let mut attn_masks = Array2::zeros((0, cols));
|
|
134
|
+
let mut type_ids = with_type_ids.then(|| Array2::zeros((0, cols)));
|
|
124
135
|
|
|
125
136
|
for encoding in encodings {
|
|
126
|
-
input_ids.
|
|
127
|
-
attn_masks.
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
type_ids.extend(encoding.get_type_ids().iter().map(|&v| i64::from(v)));
|
|
137
|
+
input_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_ids())))?;
|
|
138
|
+
attn_masks.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_attention_mask())))?;
|
|
139
|
+
if let Some(ref mut type_ids) = type_ids {
|
|
140
|
+
type_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_type_ids())))?;
|
|
131
141
|
}
|
|
132
142
|
}
|
|
133
143
|
|
|
134
|
-
Tokenized {
|
|
144
|
+
Ok(Tokenized { input_ids, attn_masks, type_ids })
|
|
135
145
|
}
|
|
136
146
|
|
|
137
147
|
#[cfg(test)]
|
|
@@ -154,9 +164,6 @@ mod tests {
|
|
|
154
164
|
|
|
155
165
|
#[test]
|
|
156
166
|
fn resolve_padding_strategy_auto_always_uses_batch_longest() {
|
|
157
|
-
// Auto ignores fixed_padding_length from tokenizer.json — BatchLongest is
|
|
158
|
-
// always faster for inference and correct for variable-length inputs.
|
|
159
|
-
// Use PaddingMode::Fixed explicitly when fixed-length padding is required.
|
|
160
167
|
assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)), PaddingStrategy::BatchLongest));
|
|
161
168
|
assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 512, None), PaddingStrategy::BatchLongest));
|
|
162
169
|
}
|