gte 0.0.6 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +16 -8
- data/Rakefile +38 -3
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +4 -4
- data/ext/gte/src/embedder.rs +42 -33
- data/ext/gte/src/model_config.rs +18 -0
- data/ext/gte/src/model_profile.rs +129 -33
- data/ext/gte/src/pipeline.rs +12 -9
- data/ext/gte/src/reranker.rs +49 -31
- data/ext/gte/src/ruby_embedder.rs +73 -113
- data/ext/gte/src/session.rs +279 -15
- data/ext/gte/src/tokenizer.rs +99 -14
- data/ext/gte/tests/inference_integration_test.rs +5 -4
- data/ext/gte/tests/tokenizer_unit_test.rs +5 -2
- data/lib/gte/config.rb +2 -2
- data/lib/gte/embedder.rb +7 -4
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +1 -10
- metadata +6 -6
data/ext/gte/src/session.rs
CHANGED
|
@@ -3,12 +3,14 @@ use crate::model_config::{ExtractorMode, ModelConfig};
|
|
|
3
3
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
|
-
use ndarray::{Array2, Ix2};
|
|
6
|
+
use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
|
|
7
7
|
use ort::execution_providers::{
|
|
8
8
|
CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
|
|
9
9
|
};
|
|
10
10
|
use ort::session::Session;
|
|
11
|
-
use std::path::Path;
|
|
11
|
+
use std::path::{Path, PathBuf};
|
|
12
|
+
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
13
|
+
use std::sync::{Condvar, Mutex};
|
|
12
14
|
|
|
13
15
|
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
14
16
|
let opt_level = match config.optimization_level {
|
|
@@ -18,22 +20,176 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
18
20
|
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
19
21
|
};
|
|
20
22
|
|
|
21
|
-
|
|
22
|
-
.
|
|
23
|
-
|
|
23
|
+
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
24
|
+
GteError::Ort(e.to_string())
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
let mut builder = Session::builder()
|
|
28
|
+
.map_err(ort_err)?
|
|
29
|
+
.with_optimization_level(opt_level)
|
|
30
|
+
.map_err(ort_err)?
|
|
31
|
+
.with_memory_pattern(true)
|
|
32
|
+
.map_err(ort_err)?;
|
|
24
33
|
|
|
25
34
|
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
26
35
|
if !providers.is_empty() {
|
|
27
|
-
builder = builder
|
|
36
|
+
builder = builder
|
|
37
|
+
.with_execution_providers(providers)
|
|
38
|
+
.map_err(ort_err)?;
|
|
28
39
|
}
|
|
29
40
|
|
|
30
41
|
if config.num_threads > 0 {
|
|
31
|
-
builder = builder
|
|
42
|
+
builder = builder
|
|
43
|
+
.with_intra_threads(config.num_threads)
|
|
44
|
+
.map_err(ort_err)?;
|
|
45
|
+
builder = builder
|
|
46
|
+
.with_inter_threads(config.num_threads)
|
|
47
|
+
.map_err(ort_err)?;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
builder.commit_from_file(model_path).map_err(ort_err)
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
// ---------------------------------------------------------------------------
|
|
54
|
+
// Session pool
|
|
55
|
+
// ---------------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
const AUTO_THREAD_POOL_CAP: usize = 6;
|
|
58
|
+
|
|
59
|
+
/// Keep enough sessions to cover the configured thread budget without
|
|
60
|
+
/// oversubscribing CPU parallelism. In ORT auto-thread mode (`num_threads == 0`)
|
|
61
|
+
/// we still keep a modest pool because request-level concurrency benefits from
|
|
62
|
+
/// more than one session even when ORT manages thread counts internally.
|
|
63
|
+
fn pool_capacity(num_threads: usize) -> usize {
|
|
64
|
+
let available_parallelism = std::thread::available_parallelism()
|
|
65
|
+
.map(|n| n.get())
|
|
66
|
+
.unwrap_or(1);
|
|
67
|
+
pool_capacity_with_parallelism(num_threads, available_parallelism)
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
fn pool_capacity_with_parallelism(num_threads: usize, available_parallelism: usize) -> usize {
|
|
71
|
+
if available_parallelism == 0 {
|
|
72
|
+
return 1;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
if num_threads == 0 {
|
|
76
|
+
return available_parallelism.clamp(1, AUTO_THREAD_POOL_CAP);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
available_parallelism.div_ceil(num_threads).max(1)
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
pub struct SessionPool {
|
|
83
|
+
sessions: Mutex<Vec<Session>>,
|
|
84
|
+
available: Condvar,
|
|
85
|
+
created: AtomicUsize,
|
|
86
|
+
capacity: usize,
|
|
87
|
+
model_path: PathBuf,
|
|
88
|
+
build_config: ModelConfig,
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
impl SessionPool {
|
|
92
|
+
pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
|
|
93
|
+
let capacity = pool_capacity(build_config.num_threads);
|
|
94
|
+
Self {
|
|
95
|
+
sessions: Mutex::new(vec![initial]),
|
|
96
|
+
available: Condvar::new(),
|
|
97
|
+
created: AtomicUsize::new(1),
|
|
98
|
+
capacity,
|
|
99
|
+
model_path,
|
|
100
|
+
build_config,
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
pub fn acquire(&self) -> Result<PooledSession<'_>> {
|
|
105
|
+
if let Some(session) = self.take_available_session() {
|
|
106
|
+
return Ok(PooledSession {
|
|
107
|
+
pool: self,
|
|
108
|
+
session: Some(session),
|
|
109
|
+
});
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
if let Some(session) = self.try_grow()? {
|
|
113
|
+
return Ok(PooledSession {
|
|
114
|
+
pool: self,
|
|
115
|
+
session: Some(session),
|
|
116
|
+
});
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
let session = self.wait_for_session();
|
|
120
|
+
Ok(PooledSession {
|
|
121
|
+
pool: self,
|
|
122
|
+
session: Some(session),
|
|
123
|
+
})
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
fn release(&self, session: Session) {
|
|
127
|
+
self.sessions.lock().unwrap().push(session);
|
|
128
|
+
self.available.notify_one();
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
fn take_available_session(&self) -> Option<Session> {
|
|
132
|
+
self.sessions.lock().unwrap().pop()
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
fn try_grow(&self) -> Result<Option<Session>> {
|
|
136
|
+
let grew = self
|
|
137
|
+
.created
|
|
138
|
+
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
|
|
139
|
+
(count < self.capacity).then_some(count + 1)
|
|
140
|
+
});
|
|
141
|
+
if grew.is_err() {
|
|
142
|
+
return Ok(None);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
match build_session(&self.model_path, &self.build_config) {
|
|
146
|
+
Ok(session) => Ok(Some(session)),
|
|
147
|
+
Err(error) => {
|
|
148
|
+
self.created.fetch_sub(1, Ordering::AcqRel);
|
|
149
|
+
Err(error)
|
|
150
|
+
}
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
fn wait_for_session(&self) -> Session {
|
|
155
|
+
let mut lock = self.sessions.lock().unwrap();
|
|
156
|
+
loop {
|
|
157
|
+
if let Some(session) = lock.pop() {
|
|
158
|
+
return session;
|
|
159
|
+
}
|
|
160
|
+
lock = self.available.wait(lock).unwrap();
|
|
161
|
+
}
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
pub struct PooledSession<'a> {
|
|
166
|
+
pool: &'a SessionPool,
|
|
167
|
+
session: Option<Session>,
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
impl std::ops::Deref for PooledSession<'_> {
|
|
171
|
+
type Target = Session;
|
|
172
|
+
fn deref(&self) -> &Session {
|
|
173
|
+
self.session.as_ref().unwrap()
|
|
32
174
|
}
|
|
175
|
+
}
|
|
33
176
|
|
|
34
|
-
|
|
177
|
+
impl std::ops::DerefMut for PooledSession<'_> {
|
|
178
|
+
fn deref_mut(&mut self) -> &mut Session {
|
|
179
|
+
self.session.as_mut().unwrap()
|
|
180
|
+
}
|
|
35
181
|
}
|
|
36
182
|
|
|
183
|
+
impl Drop for PooledSession<'_> {
|
|
184
|
+
fn drop(&mut self) {
|
|
185
|
+
if let Some(s) = self.session.take() {
|
|
186
|
+
self.pool.release(s);
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
// ---------------------------------------------------------------------------
|
|
192
|
+
|
|
37
193
|
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
38
194
|
let order = resolve_provider_order(order_override);
|
|
39
195
|
|
|
@@ -55,7 +211,10 @@ fn resolve_provider_order(order_override: Option<&str>) -> String {
|
|
|
55
211
|
resolve_provider_order_with_env(order_override, env_order.as_deref())
|
|
56
212
|
}
|
|
57
213
|
|
|
58
|
-
fn resolve_provider_order_with_env(
|
|
214
|
+
fn resolve_provider_order_with_env(
|
|
215
|
+
order_override: Option<&str>,
|
|
216
|
+
env_order: Option<&str>,
|
|
217
|
+
) -> String {
|
|
59
218
|
order_override
|
|
60
219
|
.or(env_order)
|
|
61
220
|
.unwrap_or("cpu")
|
|
@@ -75,14 +234,24 @@ fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
|
75
234
|
}
|
|
76
235
|
|
|
77
236
|
pub fn run_session(
|
|
78
|
-
session: &Session,
|
|
237
|
+
session: &mut Session,
|
|
79
238
|
tokenized: &Tokenized,
|
|
80
239
|
config: &ModelConfig,
|
|
81
240
|
) -> Result<Array2<f32>> {
|
|
82
241
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
83
|
-
let outputs = session
|
|
242
|
+
let outputs = session
|
|
243
|
+
.run(input_tensors.inputs)
|
|
244
|
+
.map_err(|e| GteError::Ort(e.to_string()))?;
|
|
84
245
|
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
85
246
|
|
|
247
|
+
extract_embeddings(array, input_tensors.attention_mask, config)
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
fn extract_embeddings(
|
|
251
|
+
array: ArrayViewD<'_, f32>,
|
|
252
|
+
attention_mask: ArrayView2<'_, i64>,
|
|
253
|
+
config: &ModelConfig,
|
|
254
|
+
) -> Result<Array2<f32>> {
|
|
86
255
|
match config.mode {
|
|
87
256
|
ExtractorMode::Token(idx) => {
|
|
88
257
|
let shape = array.shape();
|
|
@@ -102,15 +271,43 @@ pub fn run_session(
|
|
|
102
271
|
ndim
|
|
103
272
|
))
|
|
104
273
|
})?;
|
|
105
|
-
mean_pool(hidden_states
|
|
274
|
+
mean_pool(hidden_states, attention_mask)
|
|
106
275
|
}
|
|
107
|
-
ExtractorMode::Raw =>
|
|
276
|
+
ExtractorMode::Raw => array
|
|
277
|
+
.into_dimensionality::<Ix2>()
|
|
278
|
+
.map(|view| view.to_owned())
|
|
279
|
+
.map_err(|e| GteError::Shape(e.to_string())),
|
|
108
280
|
}
|
|
109
281
|
}
|
|
110
282
|
|
|
111
283
|
#[cfg(test)]
|
|
112
284
|
mod tests {
|
|
113
|
-
use
|
|
285
|
+
use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
|
|
286
|
+
use ndarray::{array, ArrayView2};
|
|
287
|
+
|
|
288
|
+
use super::{
|
|
289
|
+
extract_embeddings, parse_provider_registrations, pool_capacity_with_parallelism,
|
|
290
|
+
resolve_provider_order_with_env,
|
|
291
|
+
};
|
|
292
|
+
|
|
293
|
+
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
294
|
+
ModelConfig {
|
|
295
|
+
max_length: 8,
|
|
296
|
+
padding_mode: PaddingMode::BatchLongest,
|
|
297
|
+
output_tensor: "output".to_string(),
|
|
298
|
+
mode,
|
|
299
|
+
with_type_ids: false,
|
|
300
|
+
with_attention_mask: true,
|
|
301
|
+
num_threads: 1,
|
|
302
|
+
optimization_level: 3,
|
|
303
|
+
execution_providers: None,
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
fn empty_attention_mask() -> ArrayView2<'static, i64> {
|
|
308
|
+
static EMPTY: [i64; 0] = [];
|
|
309
|
+
ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
|
|
310
|
+
}
|
|
114
311
|
|
|
115
312
|
#[test]
|
|
116
313
|
fn parse_provider_registrations_keeps_supported_order() {
|
|
@@ -142,7 +339,74 @@ mod tests {
|
|
|
142
339
|
|
|
143
340
|
#[test]
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
    // Without an explicit override the environment value is used.
    let from_env = resolve_provider_order_with_env(None, Some("coreml"));
    assert_eq!(from_env, "coreml");

    // With neither an override nor an environment value the default is "cpu".
    let fallback = resolve_provider_order_with_env(None, None);
    assert_eq!(fallback, "cpu");
}
|
|
348
|
+
|
|
349
|
+
#[test]
fn pool_capacity_uses_bounded_parallel_pool_for_auto_thread_mode() {
    // (available parallelism, expected capacity) with num_threads == 0.
    for (available, expected) in [(1, 1), (4, 4), (8, 6)] {
        assert_eq!(pool_capacity_with_parallelism(0, available), expected);
    }
}
|
|
355
|
+
|
|
356
|
+
#[test]
fn pool_capacity_scales_with_available_parallelism() {
    // (num_threads, available parallelism, expected capacity)
    let cases = [(1, 1, 1), (1, 8, 8), (2, 8, 4), (3, 8, 3), (8, 4, 1)];
    for (num_threads, available, expected) in cases {
        assert_eq!(pool_capacity_with_parallelism(num_threads, available), expected);
    }
}
|
|
364
|
+
|
|
365
|
+
#[test]
fn extract_embeddings_raw_copies_only_final_matrix() {
    let model_output = array![[1.0f32, 2.0], [3.0, 4.0]];
    let raw_config = test_config(ExtractorMode::Raw);

    let result = extract_embeddings(
        model_output.view().into_dyn(),
        empty_attention_mask(),
        &raw_config,
    );

    // Raw mode returns the 2-D output unchanged.
    assert_eq!(result.unwrap(), model_output);
}
|
|
377
|
+
|
|
378
|
+
#[test]
fn extract_embeddings_token_selects_without_copying_full_sequence() {
    let model_output = array![
        [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
    ];
    let token_config = test_config(ExtractorMode::Token(1));

    let result = extract_embeddings(
        model_output.view().into_dyn(),
        empty_attention_mask(),
        &token_config,
    );

    // Token(1) picks the second position of every sequence in the batch.
    let second_tokens = array![[3.0f32, 4.0], [9.0, 10.0]];
    assert_eq!(result.unwrap(), second_tokens);
}
|
|
394
|
+
|
|
395
|
+
#[test]
fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
    let model_output = array![
        [[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
        [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
    ];
    // Row 0 masks out its last position; row 1 masks out its first.
    let mask = array![[1_i64, 1, 0], [0, 1, 1]];
    let pool_config = test_config(ExtractorMode::MeanPool);

    let result = extract_embeddings(model_output.view().into_dyn(), mask.view(), &pool_config);

    // Each row is the mean over its unmasked positions only.
    let pooled = array![[3.0f32, 5.0], [8.0, 10.0]];
    assert_eq!(result.unwrap(), pooled);
}
|
|
148
412
|
}
|
data/ext/gte/src/tokenizer.rs
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::PaddingMode;
|
|
2
3
|
use std::path::Path;
|
|
3
4
|
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
|
|
4
5
|
|
|
@@ -20,6 +21,8 @@ impl Tokenizer {
|
|
|
20
21
|
tokenizer_path: P,
|
|
21
22
|
max_length: usize,
|
|
22
23
|
with_type_ids: bool,
|
|
24
|
+
padding_mode: PaddingMode,
|
|
25
|
+
fixed_padding_length: Option<usize>,
|
|
23
26
|
) -> Result<Self> {
|
|
24
27
|
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
|
|
25
28
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
@@ -33,7 +36,7 @@ impl Tokenizer {
|
|
|
33
36
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
34
37
|
|
|
35
38
|
let padding = PaddingParams {
|
|
36
|
-
strategy:
|
|
39
|
+
strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
|
|
37
40
|
..Default::default()
|
|
38
41
|
};
|
|
39
42
|
tokenizer.with_padding(Some(padding));
|
|
@@ -73,6 +76,56 @@ impl Tokenizer {
|
|
|
73
76
|
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
74
77
|
build_tokenized(&encodings, self.with_type_ids)
|
|
75
78
|
}
|
|
79
|
+
|
|
80
|
+
pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
|
|
81
|
+
let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = candidates
|
|
82
|
+
.iter()
|
|
83
|
+
.map(|candidate| (query, candidate.as_str()).into())
|
|
84
|
+
.collect();
|
|
85
|
+
let encodings = self
|
|
86
|
+
.tokenizer
|
|
87
|
+
.encode_batch_fast(encode_inputs, true)
|
|
88
|
+
.map_err(|e| GteError::Tokenizer(e.to_string()))?;
|
|
89
|
+
build_tokenized(&encodings, self.with_type_ids)
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<PaddingMode>> {
|
|
94
|
+
let Some(raw) = value.map(str::trim).filter(|v| !v.is_empty()) else {
|
|
95
|
+
return Ok(None);
|
|
96
|
+
};
|
|
97
|
+
|
|
98
|
+
let normalized = raw.to_ascii_lowercase().replace('-', "_");
|
|
99
|
+
let parsed = match normalized.as_str() {
|
|
100
|
+
"auto" => PaddingMode::Auto,
|
|
101
|
+
"batch_longest" | "batchlongest" => PaddingMode::BatchLongest,
|
|
102
|
+
"fixed" => PaddingMode::Fixed,
|
|
103
|
+
_ => {
|
|
104
|
+
return Err(GteError::Inference(format!(
|
|
105
|
+
"invalid padding mode '{}'; expected one of: auto, batch_longest, fixed",
|
|
106
|
+
raw
|
|
107
|
+
)))
|
|
108
|
+
}
|
|
109
|
+
};
|
|
110
|
+
Ok(Some(parsed))
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
fn resolve_padding_strategy(
|
|
114
|
+
padding_mode: PaddingMode,
|
|
115
|
+
max_length: usize,
|
|
116
|
+
fixed_padding_length: Option<usize>,
|
|
117
|
+
) -> PaddingStrategy {
|
|
118
|
+
match padding_mode {
|
|
119
|
+
PaddingMode::BatchLongest => PaddingStrategy::BatchLongest,
|
|
120
|
+
PaddingMode::Fixed => PaddingStrategy::Fixed(max_length),
|
|
121
|
+
PaddingMode::Auto => {
|
|
122
|
+
if fixed_padding_length.is_some() {
|
|
123
|
+
PaddingStrategy::Fixed(max_length)
|
|
124
|
+
} else {
|
|
125
|
+
PaddingStrategy::BatchLongest
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
}
|
|
76
129
|
}
|
|
77
130
|
|
|
78
131
|
fn build_tokenized_single(
|
|
@@ -121,21 +174,17 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
121
174
|
let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
|
|
122
175
|
|
|
123
176
|
for encoding in encodings {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
);
|
|
177
|
+
for &value in encoding.get_ids() {
|
|
178
|
+
input_ids.push(i64::from(value));
|
|
179
|
+
}
|
|
180
|
+
for &value in encoding.get_attention_mask() {
|
|
181
|
+
attn_masks.push(i64::from(value));
|
|
182
|
+
}
|
|
131
183
|
|
|
132
184
|
if let Some(type_ids) = type_ids.as_mut() {
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
.iter()
|
|
137
|
-
.map(|&value| i64::from(value)),
|
|
138
|
-
);
|
|
185
|
+
for &value in encoding.get_type_ids() {
|
|
186
|
+
type_ids.push(i64::from(value));
|
|
187
|
+
}
|
|
139
188
|
}
|
|
140
189
|
}
|
|
141
190
|
|
|
@@ -147,3 +196,39 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
|
|
|
147
196
|
type_ids,
|
|
148
197
|
})
|
|
149
198
|
}
|
|
199
|
+
|
|
200
|
+
#[cfg(test)]
mod tests {
    use super::{parse_padding_mode_override, resolve_padding_strategy};
    use crate::model_config::PaddingMode;
    use tokenizers::PaddingStrategy;

    #[test]
    fn parse_padding_mode_override_accepts_expected_values() {
        // (raw input, expected mode); the hyphenated spelling is normalized.
        let cases = [
            ("auto", PaddingMode::Auto),
            ("batch-longest", PaddingMode::BatchLongest),
            ("fixed", PaddingMode::Fixed),
        ];
        for (raw, expected) in cases {
            assert_eq!(
                parse_padding_mode_override(Some(raw)).unwrap(),
                Some(expected)
            );
        }
    }

    #[test]
    fn parse_padding_mode_override_rejects_invalid_values() {
        let outcome = parse_padding_mode_override(Some("unknown"));
        assert!(outcome.is_err());
    }

    #[test]
    fn resolve_padding_strategy_uses_fixed_for_auto_when_model_has_fixed_padding() {
        let strategy = resolve_padding_strategy(PaddingMode::Auto, 64, Some(64));
        assert!(
            matches!(strategy, PaddingStrategy::Fixed(64)),
            "expected Fixed(64), got {:?}",
            strategy
        );
    }
}
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
use gte::embedder::Embedder;
|
|
2
|
+
use gte::model_config::ModelLoadOverrides;
|
|
2
3
|
|
|
3
4
|
#[test]
|
|
4
5
|
#[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json and model.onnx"]
|
|
5
6
|
fn test_e5_single_embedding_shape() {
|
|
6
7
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
7
8
|
|
|
8
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
9
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
9
10
|
.expect("embedder should initialize");
|
|
10
11
|
let result = embedder
|
|
11
12
|
.embed(vec!["query: Hello world".to_string()])
|
|
@@ -20,7 +21,7 @@ fn test_e5_single_embedding_shape() {
|
|
|
20
21
|
fn test_clip_single_embedding_shape() {
|
|
21
22
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
|
|
22
23
|
|
|
23
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
24
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
24
25
|
.expect("embedder should initialize");
|
|
25
26
|
let result = embedder
|
|
26
27
|
.embed(vec!["a photo of a cat".to_string()])
|
|
@@ -35,7 +36,7 @@ fn test_clip_single_embedding_shape() {
|
|
|
35
36
|
fn test_e5_batch_embedding_shape() {
|
|
36
37
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
37
38
|
|
|
38
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
39
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
39
40
|
.expect("embedder should initialize");
|
|
40
41
|
let texts = vec![
|
|
41
42
|
"query: first sentence".to_string(),
|
|
@@ -54,7 +55,7 @@ fn test_e5_batch_embedding_shape() {
|
|
|
54
55
|
fn test_e5_long_input_truncation_no_error() {
|
|
55
56
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
56
57
|
|
|
57
|
-
let embedder = Embedder::from_dir(DIR, 0, 3,
|
|
58
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
|
|
58
59
|
.expect("embedder should initialize");
|
|
59
60
|
let very_long_text = "word ".repeat(1000);
|
|
60
61
|
let result = embedder
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
use gte::model_config::PaddingMode;
|
|
1
2
|
use gte::tokenizer::Tokenizer;
|
|
2
3
|
|
|
3
4
|
#[test]
|
|
@@ -8,7 +9,8 @@ fn test_e5_tokenizer_output_shape() {
|
|
|
8
9
|
"/tests/fixtures/e5/tokenizer.json"
|
|
9
10
|
);
|
|
10
11
|
|
|
11
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 512, true
|
|
12
|
+
let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
|
|
13
|
+
.expect("tokenizer should load");
|
|
12
14
|
let texts = vec![
|
|
13
15
|
"Hello, world!".to_string(),
|
|
14
16
|
"A second, longer sentence to test padding behavior.".to_string(),
|
|
@@ -33,7 +35,8 @@ fn test_e5_truncation_at_max_length() {
|
|
|
33
35
|
"/tests/fixtures/e5/tokenizer.json"
|
|
34
36
|
);
|
|
35
37
|
|
|
36
|
-
let tokenizer = Tokenizer::new(TOKENIZER, 16, false
|
|
38
|
+
let tokenizer = Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None)
|
|
39
|
+
.expect("tokenizer should load");
|
|
37
40
|
let long_text = "word ".repeat(200);
|
|
38
41
|
let tokenized = tokenizer
|
|
39
42
|
.tokenize(&[long_text])
|
data/lib/gte/config.rb
CHANGED
|
@@ -4,12 +4,12 @@ module GTE
|
|
|
4
4
|
# Immutable configuration value objects passed to the native extension.
module Config
  # Settings for a text embedder.
  Text = Data.define(
    *%i[
      model_dir threads optimization_level
      model_name normalize output_tensor max_length padding execution_providers
    ]
  )

  # Settings for a reranker.
  Reranker = Data.define(
    *%i[
      model_dir threads optimization_level
      model_name sigmoid output_tensor max_length padding execution_providers
    ]
  )
end
|
|
15
15
|
end
|
data/lib/gte/embedder.rb
CHANGED
|
@@ -2,6 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
module GTE
|
|
4
4
|
class Embedder
|
|
5
|
+
DEFAULT_THREADS = 1
|
|
6
|
+
DEFAULT_OPTIMIZATION_LEVEL = 3
|
|
7
|
+
|
|
5
8
|
class << self
|
|
6
9
|
def config(model_dir)
|
|
7
10
|
cfg = default_config(model_dir)
|
|
@@ -18,21 +21,21 @@ module GTE
|
|
|
18
21
|
config.normalize,
|
|
19
22
|
config.output_tensor.to_s,
|
|
20
23
|
config.max_length || 0,
|
|
24
|
+
config.padding.to_s,
|
|
21
25
|
config.execution_providers.to_s
|
|
22
26
|
)
|
|
23
27
|
end
|
|
24
28
|
|
|
25
|
-
private
|
|
26
|
-
|
|
27
29
|
# Builds the baseline Config::Text for +model_dir+.
#
# The directory is expanded to an absolute path; threads and optimization
# level come from the class constants, normalization is on, and every other
# option is left unset (nil).
def default_config(model_dir)
  absolute_dir = File.expand_path(model_dir)

  Config::Text.new(
    model_dir: absolute_dir,
    threads: DEFAULT_THREADS,
    optimization_level: DEFAULT_OPTIMIZATION_LEVEL,
    model_name: nil,
    normalize: true,
    output_tensor: nil,
    max_length: nil,
    padding: nil,
    execution_providers: nil
  )
end
|
data/lib/gte/reranker.rb
CHANGED
|
@@ -19,12 +19,13 @@ module GTE
|
|
|
19
19
|
# Builds the baseline Config::Reranker for +model_dir+.
#
# The directory is expanded to an absolute path; one thread and optimization
# level 3 are used, sigmoid is off, and every other option is left unset (nil).
def default_config(model_dir)
  absolute_dir = File.expand_path(model_dir)

  Config::Reranker.new(
    model_dir: absolute_dir,
    threads: 1,
    optimization_level: 3,
    model_name: nil,
    sigmoid: false,
    output_tensor: nil,
    max_length: nil,
    padding: nil,
    execution_providers: nil
  )
end
|
|
@@ -38,6 +39,7 @@ module GTE
|
|
|
38
39
|
cfg.sigmoid,
|
|
39
40
|
cfg.output_tensor.to_s,
|
|
40
41
|
cfg.max_length || 0,
|
|
42
|
+
cfg.padding.to_s,
|
|
41
43
|
cfg.execution_providers.to_s
|
|
42
44
|
)
|
|
43
45
|
end
|
data/lib/gte.rb
CHANGED
|
@@ -19,16 +19,7 @@ module GTE
|
|
|
19
19
|
|
|
20
20
|
class << self
|
|
21
21
|
def config(model_dir)
|
|
22
|
-
cfg =
|
|
23
|
-
model_dir: File.expand_path(model_dir),
|
|
24
|
-
threads: 3,
|
|
25
|
-
optimization_level: 3,
|
|
26
|
-
model_name: nil,
|
|
27
|
-
normalize: true,
|
|
28
|
-
output_tensor: nil,
|
|
29
|
-
max_length: nil,
|
|
30
|
-
execution_providers: nil
|
|
31
|
-
)
|
|
22
|
+
cfg = Embedder.default_config(model_dir)
|
|
32
23
|
|
|
33
24
|
cfg = yield(cfg) if block_given?
|
|
34
25
|
|