gte 0.0.14-aarch64-linux → 0.0.16-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +0 -1
- data/README.md +112 -82
- data/Rakefile +0 -9
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +2 -1
- data/ext/gte/src/embedder.rs +29 -65
- data/ext/gte/src/lib.rs +1 -0
- data/ext/gte/src/model_config.rs +0 -4
- data/ext/gte/src/pipeline.rs +8 -9
- data/ext/gte/src/postprocess.rs +8 -6
- data/ext/gte/src/reranker.rs +7 -10
- data/ext/gte/src/ruby_embedder.rs +10 -33
- data/ext/gte/src/session.rs +58 -109
- data/ext/gte/src/tokenizer.rs +45 -38
- data/ext/gte/tests/embedder_unit_test.rs +1 -1
- data/ext/gte/tests/padding_regression_test.rs +7 -25
- data/ext/gte/tests/tokenizer_unit_test.rs +7 -7
- data/lib/gte/config.rb +1 -2
- data/lib/gte/embedder.rb +2 -14
- data/lib/gte/gte.so +0 -0
- data/lib/gte/model.rb +0 -7
- data/lib/gte/reranker.rb +14 -33
- data/lib/gte.rb +4 -25
- metadata +2 -2
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -6,7 +6,7 @@ use crate::model_profile::{
|
|
|
6
6
|
};
|
|
7
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
8
8
|
use crate::postprocess::sigmoid_scores;
|
|
9
|
-
use crate::session::{build_session, SessionPool};
|
|
9
|
+
use crate::session::{build_session, resolve_pool_size, SessionPool};
|
|
10
10
|
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
11
11
|
use std::path::{Path, PathBuf};
|
|
12
12
|
|
|
@@ -54,8 +54,6 @@ impl Reranker {
|
|
|
54
54
|
with_attention_mask: true,
|
|
55
55
|
optimization_level,
|
|
56
56
|
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
57
|
-
lowercase_input: false,
|
|
58
|
-
max_input_chars: None,
|
|
59
57
|
};
|
|
60
58
|
let session = build_session(&model_path, &probe_config)?;
|
|
61
59
|
|
|
@@ -83,10 +81,8 @@ impl Reranker {
|
|
|
83
81
|
with_attention_mask: config.with_attention_mask,
|
|
84
82
|
optimization_level,
|
|
85
83
|
execution_providers: None,
|
|
86
|
-
lowercase_input: false,
|
|
87
|
-
max_input_chars: None,
|
|
88
84
|
};
|
|
89
|
-
let pool = SessionPool::new(
|
|
85
|
+
let pool = SessionPool::new(&model_path, &model_config, resolve_pool_size())?;
|
|
90
86
|
Ok(Self { tokenizer, pool, config })
|
|
91
87
|
}
|
|
92
88
|
|
|
@@ -102,13 +98,12 @@ impl Reranker {
|
|
|
102
98
|
|
|
103
99
|
fn score_tokenized(&self, tokenized: &crate::tokenizer::Tokenized, apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
104
100
|
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
105
|
-
let output_name = self.config.output_tensor.clone();
|
|
106
101
|
let inputs = input_tensors.inputs;
|
|
107
102
|
|
|
108
103
|
self.pool.with_session(|session| {
|
|
109
104
|
let outputs = session.run(inputs).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
110
105
|
|
|
111
|
-
let array = extract_output_tensor(&outputs,
|
|
106
|
+
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
112
107
|
|
|
113
108
|
let mut scores = match array.ndim() {
|
|
114
109
|
1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
|
|
@@ -116,14 +111,16 @@ impl Reranker {
|
|
|
116
111
|
let shape = array.shape();
|
|
117
112
|
if shape[1] == 0 {
|
|
118
113
|
return Err(GteError::Inference(format!(
|
|
119
|
-
"reranker output '{
|
|
114
|
+
"reranker output '{}' has invalid shape {shape:?}",
|
|
115
|
+
self.config.output_tensor
|
|
120
116
|
)));
|
|
121
117
|
}
|
|
122
118
|
array.slice(ndarray::s![.., 0]).to_vec()
|
|
123
119
|
}
|
|
124
120
|
n => {
|
|
125
121
|
return Err(GteError::Inference(format!(
|
|
126
|
-
"reranker output '{
|
|
122
|
+
"reranker output '{}' rank {n} is unsupported; expected rank 1 or 2",
|
|
123
|
+
self.config.output_tensor
|
|
127
124
|
)))
|
|
128
125
|
}
|
|
129
126
|
};
|
data/ext/gte/src/ruby_embedder.rs
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
#![allow(unused_results)]
|
|
4
4
|
#![allow(unused_qualifications)]
|
|
5
5
|
|
|
6
|
-
use crate::embedder::
|
|
6
|
+
use crate::embedder::Embedder;
|
|
7
7
|
use crate::error::GteError;
|
|
8
8
|
use crate::model_config::ModelLoadOverrides;
|
|
9
9
|
use crate::reranker::Reranker;
|
|
@@ -15,7 +15,6 @@ use std::sync::Arc;
|
|
|
15
15
|
#[wrap(class = "GTE::Embedder", free_immediately, size)]
|
|
16
16
|
pub struct RbEmbedder {
|
|
17
17
|
inner: Arc<Embedder>,
|
|
18
|
-
normalize: bool,
|
|
19
18
|
}
|
|
20
19
|
|
|
21
20
|
#[wrap(class = "GTE::Reranker", free_immediately, size)]
|
|
@@ -38,7 +37,6 @@ pub struct RbTensor {
|
|
|
38
37
|
struct InferArgs {
|
|
39
38
|
embedder: *const Embedder,
|
|
40
39
|
texts: *const Vec<String>,
|
|
41
|
-
normalize: bool,
|
|
42
40
|
result: Option<crate::error::Result<ndarray::Array2<f32>>>,
|
|
43
41
|
}
|
|
44
42
|
|
|
@@ -66,15 +64,7 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
|
|
|
66
64
|
|
|
67
65
|
unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
68
66
|
let args = &mut *(ptr as *mut InferArgs);
|
|
69
|
-
let run_result = catch_unwind(AssertUnwindSafe(||
|
|
70
|
-
// Full embedding path (tokenization + inference) runs without the GVL.
|
|
71
|
-
let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
|
|
72
|
-
if args.normalize {
|
|
73
|
-
Ok(normalize_l2(embeddings))
|
|
74
|
-
} else {
|
|
75
|
-
Ok(embeddings)
|
|
76
|
-
}
|
|
77
|
-
}));
|
|
67
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| (*args.embedder).embed(&*args.texts)));
|
|
78
68
|
args.result = Some(match run_result {
|
|
79
69
|
Ok(result) => result,
|
|
80
70
|
Err(payload) => {
|
|
@@ -97,14 +87,9 @@ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
|
97
87
|
std::ptr::null_mut()
|
|
98
88
|
}
|
|
99
89
|
|
|
100
|
-
fn infer_without_gvl(
|
|
101
|
-
embedder: &Arc<Embedder>,
|
|
102
|
-
normalize: bool,
|
|
103
|
-
texts: Vec<String>,
|
|
104
|
-
) -> Result<ndarray::Array2<f32>, Error> {
|
|
90
|
+
fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
|
|
105
91
|
let embeddings = unsafe {
|
|
106
|
-
let mut args =
|
|
107
|
-
InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, normalize, result: None };
|
|
92
|
+
let mut args = InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, result: None };
|
|
108
93
|
rb_sys::rb_thread_call_without_gvl(
|
|
109
94
|
Some(run_embed_without_gvl),
|
|
110
95
|
&mut args as *mut InferArgs as *mut c_void,
|
|
@@ -167,13 +152,10 @@ impl RbEmbedder {
|
|
|
167
152
|
dir_path: String,
|
|
168
153
|
optimization_level: u8,
|
|
169
154
|
model_name: String,
|
|
170
|
-
normalize: bool,
|
|
171
155
|
output_tensor: String,
|
|
172
156
|
max_length: usize,
|
|
173
157
|
padding: String,
|
|
174
158
|
execution_providers: String,
|
|
175
|
-
lowercase_input: bool,
|
|
176
|
-
max_input_chars: usize,
|
|
177
159
|
) -> Result<Self, Error> {
|
|
178
160
|
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
179
161
|
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
@@ -181,29 +163,26 @@ impl RbEmbedder {
|
|
|
181
163
|
let execution_providers_override =
|
|
182
164
|
if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
183
165
|
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
184
|
-
let max_input_chars_override = if max_input_chars == 0 { None } else { Some(max_input_chars) };
|
|
185
166
|
let overrides = ModelLoadOverrides {
|
|
186
167
|
model_name: name,
|
|
187
168
|
output_tensor: output_override,
|
|
188
169
|
max_length: max_length_override,
|
|
189
170
|
padding: padding_override,
|
|
190
171
|
execution_providers: execution_providers_override,
|
|
191
|
-
|
|
192
|
-
max_input_chars: max_input_chars_override,
|
|
172
|
+
..ModelLoadOverrides::default()
|
|
193
173
|
};
|
|
194
174
|
let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
|
|
195
|
-
|
|
196
|
-
Ok(RbEmbedder { inner: Arc::new(embedder), normalize: normalize && !skip_normalize })
|
|
175
|
+
Ok(RbEmbedder { inner: Arc::new(embedder) })
|
|
197
176
|
}
|
|
198
177
|
|
|
199
178
|
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
|
|
200
179
|
let texts: Vec<String> = texts.to_vec()?;
|
|
201
|
-
let embeddings = infer_without_gvl(&rb_self.inner,
|
|
180
|
+
let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
|
|
202
181
|
tensor_from_array(embeddings)
|
|
203
182
|
}
|
|
204
183
|
|
|
205
184
|
pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
|
|
206
|
-
let embeddings = infer_without_gvl(&rb_self.inner,
|
|
185
|
+
let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
|
|
207
186
|
tensor_from_array(embeddings)
|
|
208
187
|
}
|
|
209
188
|
}
|
|
@@ -219,8 +198,6 @@ impl RbReranker {
|
|
|
219
198
|
max_length: usize,
|
|
220
199
|
padding: String,
|
|
221
200
|
execution_providers: String,
|
|
222
|
-
_lowercase_input: bool,
|
|
223
|
-
_max_input_chars: usize,
|
|
224
201
|
) -> Result<Self, Error> {
|
|
225
202
|
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
226
203
|
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
@@ -335,12 +312,12 @@ impl RbTensor {
|
|
|
335
312
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
336
313
|
let module = ruby.define_module("GTE")?;
|
|
337
314
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
338
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
315
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
|
|
339
316
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
340
317
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
341
318
|
|
|
342
319
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
343
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new,
|
|
320
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
|
|
344
321
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
345
322
|
|
|
346
323
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|
data/ext/gte/src/session.rs
CHANGED
|
@@ -3,77 +3,58 @@ use crate::model_config::{ExtractorMode, ModelConfig};
|
|
|
3
3
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
|
-
use ndarray::{Array2,
|
|
6
|
+
use ndarray::{Array2, ArrayViewD, Ix2};
|
|
7
7
|
use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
|
|
8
8
|
use ort::session::{OutputSelector, RunOptions, Session};
|
|
9
|
-
use
|
|
10
|
-
use std::
|
|
11
|
-
use std::collections::HashMap;
|
|
12
|
-
use std::path::{Path, PathBuf};
|
|
9
|
+
use parking_lot::Mutex;
|
|
10
|
+
use std::path::Path;
|
|
13
11
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
14
12
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
model_path: PathBuf,
|
|
24
|
-
build_config: ModelConfig,
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
thread_local! {
|
|
28
|
-
static SESSIONS: RefCell<HashMap<usize, Session>> = RefCell::new(HashMap::new());
|
|
13
|
+
/// Decides how many sessions a pool should hold.
///
/// A positive integer in the `GTE_SESSION_POOL_SIZE` environment variable
/// (surrounding whitespace ignored) wins. If the variable is unset, unparsable
/// or zero, the size is the available parallelism limited to the range 1..=4,
/// using 2 when the parallelism cannot be queried.
pub(crate) fn resolve_pool_size() -> usize {
    let from_env = std::env::var("GTE_SESSION_POOL_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&requested| requested > 0);

    match from_env {
        Some(requested) => requested,
        None => {
            let cpus = std::thread::available_parallelism().map(|count| count.get()).unwrap_or(2);
            cpus.clamp(1, 4)
        }
    }
}
|
|
30
22
|
|
|
31
23
|
pub struct SessionPool {
|
|
32
|
-
|
|
33
|
-
|
|
24
|
+
sessions: Vec<Mutex<Session>>,
|
|
25
|
+
next_idx: AtomicUsize,
|
|
34
26
|
}
|
|
35
27
|
|
|
36
28
|
impl SessionPool {
|
|
37
|
-
pub fn new(
|
|
38
|
-
let
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
Ok(Self {
|
|
45
|
-
pool_id,
|
|
46
|
-
recipe: SessionRecipe { model_path: model_path.to_path_buf(), build_config: build_config.clone() },
|
|
47
|
-
})
|
|
48
|
-
}
|
|
49
|
-
|
|
50
|
-
pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
51
|
-
self.with_session(|session| run_session(session, tokenized, config))
|
|
29
|
+
pub fn new(model_path: &Path, config: &ModelConfig, pool_size: usize) -> Result<Self> {
|
|
30
|
+
let sessions = (0..pool_size)
|
|
31
|
+
.map(|_| build_session(model_path, config))
|
|
32
|
+
.collect::<Result<Vec<_>>>()?
|
|
33
|
+
.into_iter()
|
|
34
|
+
.map(Mutex::new)
|
|
35
|
+
.collect();
|
|
36
|
+
Ok(Self { sessions, next_idx: AtomicUsize::new(0) })
|
|
52
37
|
}
|
|
53
38
|
|
|
54
39
|
pub fn with_session<F, R>(&self, f: F) -> Result<R>
|
|
55
40
|
where
|
|
56
41
|
F: FnOnce(&mut Session) -> Result<R>,
|
|
57
42
|
{
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
}
|
|
66
|
-
};
|
|
67
|
-
f(session)
|
|
68
|
-
})
|
|
43
|
+
let idx = if self.sessions.len() == 1 {
|
|
44
|
+
0
|
|
45
|
+
} else {
|
|
46
|
+
self.next_idx.fetch_add(1, Ordering::Relaxed) % self.sessions.len()
|
|
47
|
+
};
|
|
48
|
+
let mut session = self.sessions[idx].lock();
|
|
49
|
+
f(&mut session)
|
|
69
50
|
}
|
|
70
|
-
}
|
|
71
51
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
52
|
+
pub fn len(&self) -> usize {
|
|
53
|
+
self.sessions.len()
|
|
54
|
+
}
|
|
55
|
+
}
|
|
75
56
|
|
|
76
|
-
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
57
|
+
pub(crate) fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
77
58
|
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
78
59
|
GteError::Ort(e.to_string())
|
|
79
60
|
}
|
|
@@ -109,10 +90,14 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
109
90
|
}
|
|
110
91
|
|
|
111
92
|
fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
|
|
112
|
-
let mut providers = Vec::new();
|
|
113
93
|
#[cfg(target_arch = "aarch64")]
|
|
114
|
-
|
|
115
|
-
|
|
94
|
+
{
|
|
95
|
+
vec![XNNPACKExecutionProvider::default().build().fail_silently()]
|
|
96
|
+
}
|
|
97
|
+
#[cfg(not(target_arch = "aarch64"))]
|
|
98
|
+
{
|
|
99
|
+
Vec::new()
|
|
100
|
+
}
|
|
116
101
|
}
|
|
117
102
|
|
|
118
103
|
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
@@ -138,11 +123,7 @@ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionP
|
|
|
138
123
|
providers
|
|
139
124
|
}
|
|
140
125
|
|
|
141
|
-
|
|
142
|
-
// Run a single inference
|
|
143
|
-
// ---------------------------------------------------------------------------
|
|
144
|
-
|
|
145
|
-
pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
126
|
+
pub(crate) fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
146
127
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
147
128
|
let run_opts = RunOptions::new()
|
|
148
129
|
.map_err(|e| GteError::Ort(e.to_string()))?
|
|
@@ -156,7 +137,7 @@ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelC
|
|
|
156
137
|
|
|
157
138
|
fn extract_embeddings(
|
|
158
139
|
array: ArrayViewD<'_, f32>,
|
|
159
|
-
attention_mask: ArrayView2<'_, i64>,
|
|
140
|
+
attention_mask: ndarray::ArrayView2<'_, i64>,
|
|
160
141
|
config: &ModelConfig,
|
|
161
142
|
) -> Result<Array2<f32>> {
|
|
162
143
|
match config.mode {
|
|
@@ -189,21 +170,6 @@ mod tests {
|
|
|
189
170
|
|
|
190
171
|
use super::extract_embeddings;
|
|
191
172
|
|
|
192
|
-
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
193
|
-
order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
|
|
194
|
-
}
|
|
195
|
-
|
|
196
|
-
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
197
|
-
let mut providers = Vec::new();
|
|
198
|
-
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
199
|
-
match provider {
|
|
200
|
-
"xnnpack" | "coreml" => providers.push(provider),
|
|
201
|
-
_ => {}
|
|
202
|
-
}
|
|
203
|
-
}
|
|
204
|
-
providers
|
|
205
|
-
}
|
|
206
|
-
|
|
207
173
|
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
208
174
|
ModelConfig {
|
|
209
175
|
max_length: 8,
|
|
@@ -214,8 +180,6 @@ mod tests {
|
|
|
214
180
|
with_attention_mask: true,
|
|
215
181
|
optimization_level: 3,
|
|
216
182
|
execution_providers: None,
|
|
217
|
-
lowercase_input: false,
|
|
218
|
-
max_input_chars: None,
|
|
219
183
|
}
|
|
220
184
|
}
|
|
221
185
|
|
|
@@ -224,37 +188,6 @@ mod tests {
|
|
|
224
188
|
ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
|
|
225
189
|
}
|
|
226
190
|
|
|
227
|
-
#[test]
|
|
228
|
-
fn parse_provider_registrations_keeps_supported_order() {
|
|
229
|
-
let parsed = parse_provider_registrations("xnnpack,coreml");
|
|
230
|
-
assert_eq!(parsed, vec!["xnnpack", "coreml"]);
|
|
231
|
-
}
|
|
232
|
-
|
|
233
|
-
#[test]
|
|
234
|
-
fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
|
|
235
|
-
assert!(parse_provider_registrations("cpu").is_empty());
|
|
236
|
-
assert!(parse_provider_registrations("none").is_empty());
|
|
237
|
-
assert!(parse_provider_registrations("none,cpu").is_empty());
|
|
238
|
-
}
|
|
239
|
-
|
|
240
|
-
#[test]
|
|
241
|
-
fn parse_provider_registrations_ignores_unknowns_and_empties() {
|
|
242
|
-
let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
|
|
243
|
-
assert_eq!(parsed, vec!["xnnpack", "coreml"]);
|
|
244
|
-
}
|
|
245
|
-
|
|
246
|
-
#[test]
|
|
247
|
-
fn resolve_provider_order_prefers_override() {
|
|
248
|
-
assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
|
|
249
|
-
assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
|
|
250
|
-
}
|
|
251
|
-
|
|
252
|
-
#[test]
|
|
253
|
-
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
254
|
-
assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
|
|
255
|
-
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
256
|
-
}
|
|
257
|
-
|
|
258
191
|
#[test]
|
|
259
192
|
fn extract_embeddings_raw_copies_only_final_matrix() {
|
|
260
193
|
let output = array![[1.0f32, 2.0], [3.0, 4.0]];
|
|
@@ -287,4 +220,20 @@ mod tests {
|
|
|
287
220
|
|
|
288
221
|
assert_eq!(extracted, expected);
|
|
289
222
|
}
|
|
223
|
+
|
|
224
|
+
#[test]
|
|
225
|
+
fn resolve_pool_size_uses_env_var() {
|
|
226
|
+
std::env::set_var("GTE_SESSION_POOL_SIZE", "16");
|
|
227
|
+
let size = super::resolve_pool_size();
|
|
228
|
+
assert_eq!(size, 16);
|
|
229
|
+
std::env::remove_var("GTE_SESSION_POOL_SIZE");
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
#[test]
|
|
233
|
+
fn resolve_pool_size_defaults_to_cpu_count_capped_at_4() {
|
|
234
|
+
// Without GTE_SESSION_POOL_SIZE, the default is min(available_parallelism, 4).max(1).
|
|
235
|
+
// On any machine with >= 1 CPU, this should return between 1 and 4.
|
|
236
|
+
let size = super::resolve_pool_size();
|
|
237
|
+
assert!((1..=4).contains(&size), "expected 1-4, got {size}");
|
|
238
|
+
}
|
|
290
239
|
}
|
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::model_config::PaddingMode;
|
|
3
|
+
use ndarray::Array2;
|
|
3
4
|
use std::path::Path;
|
|
4
5
|
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
|
|
5
6
|
|
|
6
7
|
pub struct Tokenized {
|
|
7
|
-
pub
|
|
8
|
-
pub
|
|
9
|
-
pub
|
|
10
|
-
pub attn_masks: Vec<i64>,
|
|
11
|
-
pub type_ids: Option<Vec<i64>>,
|
|
8
|
+
pub input_ids: Array2<i64>,
|
|
9
|
+
pub attn_masks: Array2<i64>,
|
|
10
|
+
pub type_ids: Option<Array2<i64>>,
|
|
12
11
|
}
|
|
13
12
|
|
|
14
13
|
pub struct Tokenizer {
|
|
@@ -24,7 +23,6 @@ impl Tokenizer {
|
|
|
24
23
|
padding_mode: PaddingMode,
|
|
25
24
|
fixed_padding_length: Option<usize>,
|
|
26
25
|
) -> Result<Self> {
|
|
27
|
-
#[allow(unused_results)]
|
|
28
26
|
{
|
|
29
27
|
let mut tokenizer =
|
|
30
28
|
tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
@@ -34,41 +32,59 @@ impl Tokenizer {
|
|
|
34
32
|
strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
|
|
35
33
|
..Default::default()
|
|
36
34
|
};
|
|
37
|
-
tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
38
|
-
tokenizer.with_padding(Some(padding));
|
|
35
|
+
let _ = tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
36
|
+
let _ = tokenizer.with_padding(Some(padding));
|
|
39
37
|
|
|
40
38
|
Ok(Self { tokenizer, with_type_ids })
|
|
41
39
|
}
|
|
42
40
|
}
|
|
43
41
|
|
|
44
42
|
pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
|
|
45
|
-
if texts.
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
43
|
+
if texts.is_empty() {
|
|
44
|
+
return Ok(Tokenized {
|
|
45
|
+
input_ids: Array2::zeros((0, 0)),
|
|
46
|
+
attn_masks: Array2::zeros((0, 0)),
|
|
47
|
+
type_ids: None,
|
|
48
|
+
});
|
|
49
49
|
}
|
|
50
50
|
|
|
51
51
|
let encode_inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
|
|
52
52
|
let encodings =
|
|
53
53
|
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
54
54
|
|
|
55
|
-
|
|
55
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
56
56
|
}
|
|
57
57
|
|
|
58
58
|
pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
|
|
59
|
+
if pairs.is_empty() {
|
|
60
|
+
return Ok(Tokenized {
|
|
61
|
+
input_ids: Array2::zeros((0, 0)),
|
|
62
|
+
attn_masks: Array2::zeros((0, 0)),
|
|
63
|
+
type_ids: None,
|
|
64
|
+
});
|
|
65
|
+
}
|
|
66
|
+
|
|
59
67
|
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
60
68
|
pairs.iter().map(|(left, right)| (left.as_str(), right.as_str()).into()).collect();
|
|
61
69
|
let encodings =
|
|
62
70
|
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
63
|
-
|
|
71
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
64
72
|
}
|
|
65
73
|
|
|
66
74
|
pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
|
|
75
|
+
if candidates.is_empty() {
|
|
76
|
+
return Ok(Tokenized {
|
|
77
|
+
input_ids: Array2::zeros((0, 0)),
|
|
78
|
+
attn_masks: Array2::zeros((0, 0)),
|
|
79
|
+
type_ids: None,
|
|
80
|
+
});
|
|
81
|
+
}
|
|
82
|
+
|
|
67
83
|
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
|
|
68
84
|
candidates.iter().map(|candidate| (query, candidate.as_str()).into()).collect();
|
|
69
85
|
let encodings =
|
|
70
86
|
self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
71
|
-
|
|
87
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
72
88
|
}
|
|
73
89
|
}
|
|
74
90
|
|
|
@@ -102,36 +118,30 @@ fn resolve_padding_strategy(
|
|
|
102
118
|
}
|
|
103
119
|
}
|
|
104
120
|
|
|
105
|
-
fn
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&v| i64::from(v)).collect();
|
|
109
|
-
let attn_masks: Vec<i64> = encoding.get_attention_mask().iter().map(|&v| i64::from(v)).collect();
|
|
110
|
-
let type_ids: Option<Vec<i64>> =
|
|
111
|
-
with_type_ids.then(|| encoding.get_type_ids().iter().map(|&v| i64::from(v)).collect());
|
|
112
|
-
|
|
113
|
-
Tokenized { rows: 1, cols, input_ids, attn_masks, type_ids }
|
|
121
|
+
/// Copies a slice of `u32` values into a new vector of `i64`, widening each
/// element without loss.
fn to_i64(array: &[u32]) -> Vec<i64> {
    array.iter().copied().map(i64::from).collect()
}
|
|
115
124
|
|
|
116
|
-
fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Tokenized {
|
|
125
|
+
fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Result<Tokenized> {
|
|
117
126
|
let rows = encodings.len();
|
|
118
127
|
let cols = encodings.first().map_or(0, tokenizers::Encoding::len);
|
|
119
|
-
|
|
128
|
+
if rows == 0 || cols == 0 {
|
|
129
|
+
return Ok(Tokenized { input_ids: Array2::zeros((0, 0)), attn_masks: Array2::zeros((0, 0)), type_ids: None });
|
|
130
|
+
}
|
|
120
131
|
|
|
121
|
-
let mut input_ids =
|
|
122
|
-
let mut attn_masks =
|
|
123
|
-
let mut type_ids = with_type_ids.then(||
|
|
132
|
+
let mut input_ids = Array2::zeros((0, cols));
|
|
133
|
+
let mut attn_masks = Array2::zeros((0, cols));
|
|
134
|
+
let mut type_ids = with_type_ids.then(|| Array2::zeros((0, cols)));
|
|
124
135
|
|
|
125
136
|
for encoding in encodings {
|
|
126
|
-
input_ids.
|
|
127
|
-
attn_masks.
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
type_ids.extend(encoding.get_type_ids().iter().map(|&v| i64::from(v)));
|
|
137
|
+
input_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_ids())))?;
|
|
138
|
+
attn_masks.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_attention_mask())))?;
|
|
139
|
+
if let Some(ref mut type_ids) = type_ids {
|
|
140
|
+
type_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_type_ids())))?;
|
|
131
141
|
}
|
|
132
142
|
}
|
|
133
143
|
|
|
134
|
-
Tokenized {
|
|
144
|
+
Ok(Tokenized { input_ids, attn_masks, type_ids })
|
|
135
145
|
}
|
|
136
146
|
|
|
137
147
|
#[cfg(test)]
|
|
@@ -154,9 +164,6 @@ mod tests {
|
|
|
154
164
|
|
|
155
165
|
#[test]
fn resolve_padding_strategy_auto_always_uses_batch_longest() {
    // Auto mode must select BatchLongest whether or not a fixed padding length is supplied.
    for (max_length, fixed_padding_length) in [(64, Some(64)), (512, None)] {
        let strategy = resolve_padding_strategy(PaddingMode::Auto, max_length, fixed_padding_length);
        assert!(matches!(strategy, PaddingStrategy::BatchLongest));
    }
}
|