gte 0.0.6 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,12 +3,14 @@ use crate::model_config::{ExtractorMode, ModelConfig};
3
3
  use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
- use ndarray::{Array2, Ix2};
6
+ use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
7
7
  use ort::execution_providers::{
8
8
  CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
9
9
  };
10
10
  use ort::session::Session;
11
- use std::path::Path;
11
+ use std::path::{Path, PathBuf};
12
+ use std::sync::atomic::{AtomicUsize, Ordering};
13
+ use std::sync::{Condvar, Mutex};
12
14
 
13
15
  pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
14
16
  let opt_level = match config.optimization_level {
@@ -18,22 +20,176 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
18
20
  _ => ort::session::builder::GraphOptimizationLevel::Level3,
19
21
  };
20
22
 
21
- let mut builder = Session::builder()?
22
- .with_optimization_level(opt_level)?
23
- .with_memory_pattern(true)?;
23
+ fn ort_err(e: impl std::fmt::Display) -> GteError {
24
+ GteError::Ort(e.to_string())
25
+ }
26
+
27
+ let mut builder = Session::builder()
28
+ .map_err(ort_err)?
29
+ .with_optimization_level(opt_level)
30
+ .map_err(ort_err)?
31
+ .with_memory_pattern(true)
32
+ .map_err(ort_err)?;
24
33
 
25
34
  let providers = preferred_execution_providers(config.execution_providers.as_deref());
26
35
  if !providers.is_empty() {
27
- builder = builder.with_execution_providers(providers)?;
36
+ builder = builder
37
+ .with_execution_providers(providers)
38
+ .map_err(ort_err)?;
28
39
  }
29
40
 
30
41
  if config.num_threads > 0 {
31
- builder = builder.with_intra_threads(config.num_threads)?;
42
+ builder = builder
43
+ .with_intra_threads(config.num_threads)
44
+ .map_err(ort_err)?;
45
+ builder = builder
46
+ .with_inter_threads(config.num_threads)
47
+ .map_err(ort_err)?;
48
+ }
49
+
50
+ builder.commit_from_file(model_path).map_err(ort_err)
51
+ }
52
+
53
+ // ---------------------------------------------------------------------------
54
+ // Session pool
55
+ // ---------------------------------------------------------------------------
56
+
57
+ const AUTO_THREAD_POOL_CAP: usize = 6;
58
+
59
+ /// Keep enough sessions to cover the configured thread budget without
60
+ /// oversubscribing CPU parallelism. In ORT auto-thread mode (`num_threads == 0`)
61
+ /// we still keep a modest pool because request-level concurrency benefits from
62
+ /// more than one session even when ORT manages thread counts internally.
63
+ fn pool_capacity(num_threads: usize) -> usize {
// A failed parallelism query is treated as a parallelism of 1.
64
+ let available_parallelism = std::thread::available_parallelism()
65
+ .map(|n| n.get())
66
+ .unwrap_or(1);
// The sizing rule lives in pool_capacity_with_parallelism, which receives the
// parallelism as a plain argument.
67
+ pool_capacity_with_parallelism(num_threads, available_parallelism)
68
+ }
69
+
70
// Sizing rule for the session pool, with the parallelism supplied by the caller
// rather than queried here.
//   - available_parallelism == 0  -> 1
//   - num_threads == 0            -> available_parallelism limited to 1..=AUTO_THREAD_POOL_CAP
//   - otherwise                   -> ceil(available_parallelism / num_threads), at least 1
// Examples: (num_threads 3, parallelism 8) -> 3; (8, 4) -> 1; (0, 8) -> 6.
+ fn pool_capacity_with_parallelism(num_threads: usize, available_parallelism: usize) -> usize {
71
+ if available_parallelism == 0 {
72
+ return 1;
73
+ }
74
+
// num_threads == 0: the pool size is the parallelism itself, capped at
// AUTO_THREAD_POOL_CAP.
75
+ if num_threads == 0 {
76
+ return available_parallelism.clamp(1, AUTO_THREAD_POOL_CAP);
77
+ }
78
+
// Rounded-up division; `.max(1)` keeps the result from dropping below one session.
79
+ available_parallelism.div_ceil(num_threads).max(1)
80
+ }
81
+
82
// Pool of `Session` values that starts with one session and builds more on
// demand, up to `capacity`.
+ pub struct SessionPool {
// Sessions currently not handed out; `release` pushes here, `acquire` pops.
83
+ sessions: Mutex<Vec<Session>>,
// Notified once per `release`; `wait_for_session` waits on it.
84
+ available: Condvar,
// Number of sessions counted against `capacity` (starts at 1 for the initial
// session; incremented before a new build, decremented if that build fails).
85
+ created: AtomicUsize,
// Upper bound for `created`, computed by `pool_capacity` in `new`.
86
+ capacity: usize,
// Arguments passed to `build_session` when the pool grows.
87
+ model_path: PathBuf,
88
+ build_config: ModelConfig,
89
+ }
90
+
91
// Operations on the session pool. Every `self.sessions.lock().unwrap()` below
// panics if the mutex has been poisoned.
+ impl SessionPool {
// Creates a pool holding `initial` as its only session. The capacity is derived
// from `build_config.num_threads`; `model_path` and `build_config` are kept for
// later `build_session` calls.
92
+ pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
93
+ let capacity = pool_capacity(build_config.num_threads);
94
+ Self {
95
+ sessions: Mutex::new(vec![initial]),
96
+ available: Condvar::new(),
// The initial session already counts against the capacity.
97
+ created: AtomicUsize::new(1),
98
+ capacity,
99
+ model_path,
100
+ build_config,
101
+ }
102
+ }
103
+
// Hands out a session wrapped in a `PooledSession`. Order of attempts:
//   1. pop an idle session;
//   2. build a new one if `created` is still below `capacity`;
//   3. block until another holder releases a session.
// Returns `Err` only when step 2 tries to build a session and the build fails.
104
+ pub fn acquire(&self) -> Result<PooledSession<'_>> {
105
+ if let Some(session) = self.take_available_session() {
106
+ return Ok(PooledSession {
107
+ pool: self,
108
+ session: Some(session),
109
+ });
110
+ }
111
+
112
+ if let Some(session) = self.try_grow()? {
113
+ return Ok(PooledSession {
114
+ pool: self,
115
+ session: Some(session),
116
+ });
117
+ }
118
+
// NOTE(review): this wait has no timeout; the caller blocks until some
// `PooledSession` is dropped.
119
+ let session = self.wait_for_session();
120
+ Ok(PooledSession {
121
+ pool: self,
122
+ session: Some(session),
123
+ })
124
+ }
125
+
// Puts a session back on the idle list and wakes one waiter. The lock guard is a
// temporary, so it is dropped before `notify_one` runs.
126
+ fn release(&self, session: Session) {
127
+ self.sessions.lock().unwrap().push(session);
128
+ self.available.notify_one();
129
+ }
130
+
// Non-blocking pop from the idle list; `None` when the list is empty.
131
+ fn take_available_session(&self) -> Option<Session> {
132
+ self.sessions.lock().unwrap().pop()
133
+ }
134
+
// Tries to build one more session.
//   Ok(None)     -> `created` has already reached `capacity`; nothing was built.
//   Ok(Some(s))  -> a slot was reserved and the session was built.
//   Err(e)       -> the build failed; the reserved slot is given back.
135
+ fn try_grow(&self) -> Result<Option<Session>> {
// Reserve the slot atomically before building, so concurrent callers cannot
// push `created` past `capacity`.
136
+ let grew = self
137
+ .created
138
+ .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
139
+ (count < self.capacity).then_some(count + 1)
140
+ });
141
+ if grew.is_err() {
142
+ return Ok(None);
143
+ }
144
+
145
+ match build_session(&self.model_path, &self.build_config) {
146
+ Ok(session) => Ok(Some(session)),
147
+ Err(error) => {
// Undo the reservation so a later call may retry the build.
148
+ self.created.fetch_sub(1, Ordering::AcqRel);
149
+ Err(error)
150
+ }
151
+ }
152
+ }
153
+
// Blocks until the idle list yields a session. The list is re-checked after
// every wake-up, so a wake-up that finds the list empty simply waits again.
154
+ fn wait_for_session(&self) -> Session {
155
+ let mut lock = self.sessions.lock().unwrap();
156
+ loop {
157
+ if let Some(session) = lock.pop() {
158
+ return session;
159
+ }
160
+ lock = self.available.wait(lock).unwrap();
161
+ }
162
+ }
163
+ }
164
+
165
// Handle for a session borrowed from a `SessionPool`; it borrows the pool for
// lifetime `'a`.
+ pub struct PooledSession<'a> {
// Pool that `SessionPool::acquire` stored when it created this handle.
166
+ pool: &'a SessionPool,
// `SessionPool::acquire` always stores `Some(..)` here.
167
+ session: Option<Session>,
168
+ }
169
+
170
// Lets a `PooledSession` be used wherever a `&Session` is expected.
+ impl std::ops::Deref for PooledSession<'_> {
171
+ type Target = Session;
172
+ fn deref(&self) -> &Session {
// Panics if `session` is `None`.
173
+ self.session.as_ref().unwrap()
32
174
  }
175
+ }
33
176
 
34
- Ok(builder.commit_from_file(model_path)?)
177
// Gives mutable access to the wrapped session through a `PooledSession`.
+ impl std::ops::DerefMut for PooledSession<'_> {
178
+ fn deref_mut(&mut self) -> &mut Session {
// Panics if `session` is `None`.
179
+ self.session.as_mut().unwrap()
180
+ }
35
181
  }
36
182
 
183
// Returns the session to the pool when the handle goes out of scope.
+ impl Drop for PooledSession<'_> {
184
+ fn drop(&mut self) {
// `take()` moves the session out and leaves `None` behind; nothing is released
// if the slot is already empty.
185
+ if let Some(s) = self.session.take() {
186
+ self.pool.release(s);
187
+ }
188
+ }
189
+ }
190
+
191
+ // ---------------------------------------------------------------------------
192
+
37
193
  fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
38
194
  let order = resolve_provider_order(order_override);
39
195
 
@@ -55,7 +211,10 @@ fn resolve_provider_order(order_override: Option<&str>) -> String {
55
211
  resolve_provider_order_with_env(order_override, env_order.as_deref())
56
212
  }
57
213
 
58
- fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
214
+ fn resolve_provider_order_with_env(
215
+ order_override: Option<&str>,
216
+ env_order: Option<&str>,
217
+ ) -> String {
59
218
  order_override
60
219
  .or(env_order)
61
220
  .unwrap_or("cpu")
@@ -75,14 +234,24 @@ fn parse_provider_registrations(order: &str) -> Vec<&str> {
75
234
  }
76
235
 
77
236
  pub fn run_session(
78
- session: &Session,
237
+ session: &mut Session,
79
238
  tokenized: &Tokenized,
80
239
  config: &ModelConfig,
81
240
  ) -> Result<Array2<f32>> {
82
241
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
83
- let outputs = session.run(input_tensors.inputs)?;
242
+ let outputs = session
243
+ .run(input_tensors.inputs)
244
+ .map_err(|e| GteError::Ort(e.to_string()))?;
84
245
  let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
85
246
 
247
+ extract_embeddings(array, input_tensors.attention_mask, config)
248
+ }
249
+
250
+ fn extract_embeddings(
251
+ array: ArrayViewD<'_, f32>,
252
+ attention_mask: ArrayView2<'_, i64>,
253
+ config: &ModelConfig,
254
+ ) -> Result<Array2<f32>> {
86
255
  match config.mode {
87
256
  ExtractorMode::Token(idx) => {
88
257
  let shape = array.shape();
@@ -102,15 +271,43 @@ pub fn run_session(
102
271
  ndim
103
272
  ))
104
273
  })?;
105
- mean_pool(hidden_states.view(), input_tensors.attention_mask)
274
+ mean_pool(hidden_states, attention_mask)
106
275
  }
107
- ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
276
+ ExtractorMode::Raw => array
277
+ .into_dimensionality::<Ix2>()
278
+ .map(|view| view.to_owned())
279
+ .map_err(|e| GteError::Shape(e.to_string())),
108
280
  }
109
281
  }
110
282
 
111
283
  #[cfg(test)]
112
284
  mod tests {
113
- use super::{parse_provider_registrations, resolve_provider_order_with_env};
285
+ use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
286
+ use ndarray::{array, ArrayView2};
287
+
288
+ use super::{
289
+ extract_embeddings, parse_provider_registrations, pool_capacity_with_parallelism,
290
+ resolve_provider_order_with_env,
291
+ };
292
+
293
+ fn test_config(mode: ExtractorMode) -> ModelConfig {
294
+ ModelConfig {
295
+ max_length: 8,
296
+ padding_mode: PaddingMode::BatchLongest,
297
+ output_tensor: "output".to_string(),
298
+ mode,
299
+ with_type_ids: false,
300
+ with_attention_mask: true,
301
+ num_threads: 1,
302
+ optimization_level: 3,
303
+ execution_providers: None,
304
+ }
305
+ }
306
+
307
+ fn empty_attention_mask() -> ArrayView2<'static, i64> {
308
+ static EMPTY: [i64; 0] = [];
309
+ ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
310
+ }
114
311
 
115
312
  #[test]
116
313
  fn parse_provider_registrations_keeps_supported_order() {
@@ -142,7 +339,74 @@ mod tests {
142
339
 
143
340
  #[test]
144
341
  fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
145
- assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
342
+ assert_eq!(
343
+ resolve_provider_order_with_env(None, Some("coreml")),
344
+ "coreml"
345
+ );
146
346
  assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
147
347
  }
348
+
349
+ #[test]
350
+ fn pool_capacity_uses_bounded_parallel_pool_for_auto_thread_mode() {
351
+ assert_eq!(pool_capacity_with_parallelism(0, 1), 1);
352
+ assert_eq!(pool_capacity_with_parallelism(0, 4), 4);
353
+ assert_eq!(pool_capacity_with_parallelism(0, 8), 6);
354
+ }
355
+
356
+ #[test]
357
+ fn pool_capacity_scales_with_available_parallelism() {
358
+ assert_eq!(pool_capacity_with_parallelism(1, 1), 1);
359
+ assert_eq!(pool_capacity_with_parallelism(1, 8), 8);
360
+ assert_eq!(pool_capacity_with_parallelism(2, 8), 4);
361
+ assert_eq!(pool_capacity_with_parallelism(3, 8), 3);
362
+ assert_eq!(pool_capacity_with_parallelism(8, 4), 1);
363
+ }
364
+
365
+ #[test]
366
+ fn extract_embeddings_raw_copies_only_final_matrix() {
367
+ let output = array![[1.0f32, 2.0], [3.0, 4.0]];
368
+ let extracted = extract_embeddings(
369
+ output.view().into_dyn(),
370
+ empty_attention_mask(),
371
+ &test_config(ExtractorMode::Raw),
372
+ )
373
+ .unwrap();
374
+
375
+ assert_eq!(extracted, output);
376
+ }
377
+
378
+ #[test]
379
+ fn extract_embeddings_token_selects_without_copying_full_sequence() {
380
+ let output = array![
381
+ [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
382
+ [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
383
+ ];
384
+ let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
385
+ let extracted = extract_embeddings(
386
+ output.view().into_dyn(),
387
+ empty_attention_mask(),
388
+ &test_config(ExtractorMode::Token(1)),
389
+ )
390
+ .unwrap();
391
+
392
+ assert_eq!(extracted, expected);
393
+ }
394
+
395
+ #[test]
396
+ fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
397
+ let output = array![
398
+ [[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
399
+ [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
400
+ ];
401
+ let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
402
+ let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
403
+ let extracted = extract_embeddings(
404
+ output.view().into_dyn(),
405
+ attention_mask.view(),
406
+ &test_config(ExtractorMode::MeanPool),
407
+ )
408
+ .unwrap();
409
+
410
+ assert_eq!(extracted, expected);
411
+ }
148
412
  }
@@ -1,4 +1,5 @@
1
1
  use crate::error::{GteError, Result};
2
+ use crate::model_config::PaddingMode;
2
3
  use std::path::Path;
3
4
  use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
4
5
 
@@ -20,6 +21,8 @@ impl Tokenizer {
20
21
  tokenizer_path: P,
21
22
  max_length: usize,
22
23
  with_type_ids: bool,
24
+ padding_mode: PaddingMode,
25
+ fixed_padding_length: Option<usize>,
23
26
  ) -> Result<Self> {
24
27
  let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
25
28
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
@@ -33,7 +36,7 @@ impl Tokenizer {
33
36
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
34
37
 
35
38
  let padding = PaddingParams {
36
- strategy: PaddingStrategy::BatchLongest,
39
+ strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
37
40
  ..Default::default()
38
41
  };
39
42
  tokenizer.with_padding(Some(padding));
@@ -73,6 +76,56 @@ impl Tokenizer {
73
76
  .map_err(|e| GteError::Tokenizer(e.to_string()))?;
74
77
  build_tokenized(&encodings, self.with_type_ids)
75
78
  }
79
+
80
// Tokenizes one (query, candidate) pair per entry of `candidates`, in order, and
// packs the encodings with `build_tokenized`. Tokenizer failures are returned as
// `GteError::Tokenizer` carrying the error's string form.
+ pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
// The same `query` is paired with every candidate.
81
+ let encode_inputs: Vec<tokenizers::EncodeInput<'_>> = candidates
82
+ .iter()
83
+ .map(|candidate| (query, candidate.as_str()).into())
84
+ .collect();
// NOTE(review): the `true` argument is presumably `add_special_tokens` - confirm
// against the tokenizers API.
85
+ let encodings = self
86
+ .tokenizer
87
+ .encode_batch_fast(encode_inputs, true)
88
+ .map_err(|e| GteError::Tokenizer(e.to_string()))?;
89
+ build_tokenized(&encodings, self.with_type_ids)
90
+ }
91
+ }
92
+
93
// Parses an optional padding-mode string.
//   - `None`, empty, or whitespace-only input -> Ok(None)
//   - "auto" / "batch_longest" / "batchlongest" / "fixed" -> Ok(Some(mode))
//   - anything else -> Err(GteError::Inference) quoting the trimmed input
// Matching ignores ASCII case and treats '-' the same as '_'.
+ pub fn parse_padding_mode_override(value: Option<&str>) -> Result<Option<PaddingMode>> {
94
+ let Some(raw) = value.map(str::trim).filter(|v| !v.is_empty()) else {
95
+ return Ok(None);
96
+ };
97
+
// Normalisation makes "Batch-Longest" and "batch_longest" compare equal.
98
+ let normalized = raw.to_ascii_lowercase().replace('-', "_");
99
+ let parsed = match normalized.as_str() {
100
+ "auto" => PaddingMode::Auto,
101
+ "batch_longest" | "batchlongest" => PaddingMode::BatchLongest,
102
+ "fixed" => PaddingMode::Fixed,
103
+ _ => {
// The message reports `raw` (trimmed, but not lower-cased).
104
+ return Err(GteError::Inference(format!(
105
+ "invalid padding mode '{}'; expected one of: auto, batch_longest, fixed",
106
+ raw
107
+ )))
108
+ }
109
+ };
110
+ Ok(Some(parsed))
111
+ }
112
+
113
// Maps a `PaddingMode` to the tokenizer's `PaddingStrategy`.
//   - BatchLongest -> PaddingStrategy::BatchLongest
//   - Fixed        -> PaddingStrategy::Fixed(max_length)
//   - Auto         -> Fixed(max_length) when `fixed_padding_length` is Some,
//                     otherwise BatchLongest
// NOTE(review): only the presence of `fixed_padding_length` is checked; its value
// is never read, and both Fixed results pad to `max_length`. Confirm that callers
// guarantee the two lengths agree, or that padding to `max_length` is intended.
+ fn resolve_padding_strategy(
114
+ padding_mode: PaddingMode,
115
+ max_length: usize,
116
+ fixed_padding_length: Option<usize>,
117
+ ) -> PaddingStrategy {
118
+ match padding_mode {
119
+ PaddingMode::BatchLongest => PaddingStrategy::BatchLongest,
120
+ PaddingMode::Fixed => PaddingStrategy::Fixed(max_length),
121
+ PaddingMode::Auto => {
122
+ if fixed_padding_length.is_some() {
123
+ PaddingStrategy::Fixed(max_length)
124
+ } else {
125
+ PaddingStrategy::BatchLongest
126
+ }
127
+ }
128
+ }
76
129
  }
77
130
 
78
131
  fn build_tokenized_single(
@@ -121,21 +174,17 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
121
174
  let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
122
175
 
123
176
  for encoding in encodings {
124
- input_ids.extend(encoding.get_ids().iter().map(|&value| i64::from(value)));
125
- attn_masks.extend(
126
- encoding
127
- .get_attention_mask()
128
- .iter()
129
- .map(|&value| i64::from(value)),
130
- );
177
+ for &value in encoding.get_ids() {
178
+ input_ids.push(i64::from(value));
179
+ }
180
+ for &value in encoding.get_attention_mask() {
181
+ attn_masks.push(i64::from(value));
182
+ }
131
183
 
132
184
  if let Some(type_ids) = type_ids.as_mut() {
133
- type_ids.extend(
134
- encoding
135
- .get_type_ids()
136
- .iter()
137
- .map(|&value| i64::from(value)),
138
- );
185
+ for &value in encoding.get_type_ids() {
186
+ type_ids.push(i64::from(value));
187
+ }
139
188
  }
140
189
  }
141
190
 
@@ -147,3 +196,39 @@ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> R
147
196
  type_ids,
148
197
  })
149
198
  }
199
+
200
// Unit tests for `parse_padding_mode_override` and `resolve_padding_strategy`.
+ #[cfg(test)]
201
+ mod tests {
202
+ use super::{parse_padding_mode_override, resolve_padding_strategy};
203
+ use crate::model_config::PaddingMode;
204
+ use tokenizers::PaddingStrategy;
205
+
// Each accepted spelling maps to its `PaddingMode`; "batch-longest" exercises the
// '-' to '_' normalisation.
206
+ #[test]
207
+ fn parse_padding_mode_override_accepts_expected_values() {
208
+ assert_eq!(
209
+ parse_padding_mode_override(Some("auto")).unwrap(),
210
+ Some(PaddingMode::Auto)
211
+ );
212
+ assert_eq!(
213
+ parse_padding_mode_override(Some("batch-longest")).unwrap(),
214
+ Some(PaddingMode::BatchLongest)
215
+ );
216
+ assert_eq!(
217
+ parse_padding_mode_override(Some("fixed")).unwrap(),
218
+ Some(PaddingMode::Fixed)
219
+ );
220
+ }
221
+
// An unrecognised value is reported as an error.
222
+ #[test]
223
+ fn parse_padding_mode_override_rejects_invalid_values() {
224
+ assert!(parse_padding_mode_override(Some("unknown")).is_err());
225
+ }
226
+
// NOTE(review): `max_length` and the fixed length are both 64 here, so this test
// cannot tell which of the two values ends up inside `Fixed(..)`.
227
+ #[test]
228
+ fn resolve_padding_strategy_uses_fixed_for_auto_when_model_has_fixed_padding() {
229
+ match resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)) {
230
+ PaddingStrategy::Fixed(64) => {}
231
+ other => panic!("expected Fixed(64), got {:?}", other),
232
+ }
233
+ }
234
+ }
@@ -1,11 +1,12 @@
1
1
  use gte::embedder::Embedder;
2
+ use gte::model_config::ModelLoadOverrides;
2
3
 
3
4
  #[test]
4
5
  #[ignore = "requires ext/gte/tests/fixtures/e5/tokenizer.json and model.onnx"]
5
6
  fn test_e5_single_embedding_shape() {
6
7
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
7
8
 
8
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
9
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
9
10
  .expect("embedder should initialize");
10
11
  let result = embedder
11
12
  .embed(vec!["query: Hello world".to_string()])
@@ -20,7 +21,7 @@ fn test_e5_single_embedding_shape() {
20
21
  fn test_clip_single_embedding_shape() {
21
22
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
22
23
 
23
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
24
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
24
25
  .expect("embedder should initialize");
25
26
  let result = embedder
26
27
  .embed(vec!["a photo of a cat".to_string()])
@@ -35,7 +36,7 @@ fn test_clip_single_embedding_shape() {
35
36
  fn test_e5_batch_embedding_shape() {
36
37
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
37
38
 
38
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
39
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
39
40
  .expect("embedder should initialize");
40
41
  let texts = vec![
41
42
  "query: first sentence".to_string(),
@@ -54,7 +55,7 @@ fn test_e5_batch_embedding_shape() {
54
55
  fn test_e5_long_input_truncation_no_error() {
55
56
  const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
56
57
 
57
- let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
58
+ let embedder = Embedder::from_dir(DIR, 0, 3, ModelLoadOverrides::default())
58
59
  .expect("embedder should initialize");
59
60
  let very_long_text = "word ".repeat(1000);
60
61
  let result = embedder
@@ -1,3 +1,4 @@
1
+ use gte::model_config::PaddingMode;
1
2
  use gte::tokenizer::Tokenizer;
2
3
 
3
4
  #[test]
@@ -8,7 +9,8 @@ fn test_e5_tokenizer_output_shape() {
8
9
  "/tests/fixtures/e5/tokenizer.json"
9
10
  );
10
11
 
11
- let tokenizer = Tokenizer::new(TOKENIZER, 512, true).expect("tokenizer should load");
12
+ let tokenizer = Tokenizer::new(TOKENIZER, 512, true, PaddingMode::BatchLongest, None)
13
+ .expect("tokenizer should load");
12
14
  let texts = vec![
13
15
  "Hello, world!".to_string(),
14
16
  "A second, longer sentence to test padding behavior.".to_string(),
@@ -33,7 +35,8 @@ fn test_e5_truncation_at_max_length() {
33
35
  "/tests/fixtures/e5/tokenizer.json"
34
36
  );
35
37
 
36
- let tokenizer = Tokenizer::new(TOKENIZER, 16, false).expect("tokenizer should load");
38
+ let tokenizer = Tokenizer::new(TOKENIZER, 16, false, PaddingMode::BatchLongest, None)
39
+ .expect("tokenizer should load");
37
40
  let long_text = "word ".repeat(200);
38
41
  let tokenized = tokenizer
39
42
  .tokenize(&[long_text])
data/lib/gte/config.rb CHANGED
@@ -4,12 +4,12 @@ module GTE
4
4
  module Config
5
5
  Text = Data.define(
6
6
  :model_dir, :threads, :optimization_level,
7
- :model_name, :normalize, :output_tensor, :max_length, :execution_providers
7
+ :model_name, :normalize, :output_tensor, :max_length, :padding, :execution_providers
8
8
  )
9
9
 
10
10
  Reranker = Data.define(
11
11
  :model_dir, :threads, :optimization_level,
12
- :model_name, :sigmoid, :output_tensor, :max_length, :execution_providers
12
+ :model_name, :sigmoid, :output_tensor, :max_length, :padding, :execution_providers
13
13
  )
14
14
  end
15
15
  end
data/lib/gte/embedder.rb CHANGED
@@ -2,6 +2,9 @@
2
2
 
3
3
  module GTE
4
4
  class Embedder
5
+ DEFAULT_THREADS = 1
6
+ DEFAULT_OPTIMIZATION_LEVEL = 3
7
+
5
8
  class << self
6
9
  def config(model_dir)
7
10
  cfg = default_config(model_dir)
@@ -18,21 +21,21 @@ module GTE
18
21
  config.normalize,
19
22
  config.output_tensor.to_s,
20
23
  config.max_length || 0,
24
+ config.padding.to_s,
21
25
  config.execution_providers.to_s
22
26
  )
23
27
  end
24
28
 
25
- private
26
-
27
29
  def default_config(model_dir)
28
30
  Config::Text.new(
29
31
  model_dir: File.expand_path(model_dir),
30
- threads: 3,
31
- optimization_level: 3,
32
+ threads: DEFAULT_THREADS,
33
+ optimization_level: DEFAULT_OPTIMIZATION_LEVEL,
32
34
  model_name: nil,
33
35
  normalize: true,
34
36
  output_tensor: nil,
35
37
  max_length: nil,
38
+ padding: nil,
36
39
  execution_providers: nil
37
40
  )
38
41
  end
data/lib/gte/reranker.rb CHANGED
@@ -19,12 +19,13 @@ module GTE
19
19
  def default_config(model_dir)
20
20
  Config::Reranker.new(
21
21
  model_dir: File.expand_path(model_dir),
22
- threads: 3,
22
+ threads: 1,
23
23
  optimization_level: 3,
24
24
  model_name: nil,
25
25
  sigmoid: false,
26
26
  output_tensor: nil,
27
27
  max_length: nil,
28
+ padding: nil,
28
29
  execution_providers: nil
29
30
  )
30
31
  end
@@ -38,6 +39,7 @@ module GTE
38
39
  cfg.sigmoid,
39
40
  cfg.output_tensor.to_s,
40
41
  cfg.max_length || 0,
42
+ cfg.padding.to_s,
41
43
  cfg.execution_providers.to_s
42
44
  )
43
45
  end
data/lib/gte.rb CHANGED
@@ -19,16 +19,7 @@ module GTE
19
19
 
20
20
  class << self
21
21
  def config(model_dir)
22
- cfg = Config::Text.new(
23
- model_dir: File.expand_path(model_dir),
24
- threads: 3,
25
- optimization_level: 3,
26
- model_name: nil,
27
- normalize: true,
28
- output_tensor: nil,
29
- max_length: nil,
30
- execution_providers: nil
31
- )
22
+ cfg = Embedder.default_config(model_dir)
32
23
 
33
24
  cfg = yield(cfg) if block_given?
34
25