gte 0.0.14-aarch64-linux → 0.0.16-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ use crate::model_profile::{
6
6
  };
7
7
  use crate::pipeline::{extract_output_tensor, InputTensors};
8
8
  use crate::postprocess::sigmoid_scores;
9
- use crate::session::{build_session, SessionPool};
9
+ use crate::session::{build_session, resolve_pool_size, SessionPool};
10
10
  use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
11
11
  use std::path::{Path, PathBuf};
12
12
 
@@ -54,8 +54,6 @@ impl Reranker {
54
54
  with_attention_mask: true,
55
55
  optimization_level,
56
56
  execution_providers: overrides.execution_providers.map(str::to_string),
57
- lowercase_input: false,
58
- max_input_chars: None,
59
57
  };
60
58
  let session = build_session(&model_path, &probe_config)?;
61
59
 
@@ -83,10 +81,8 @@ impl Reranker {
83
81
  with_attention_mask: config.with_attention_mask,
84
82
  optimization_level,
85
83
  execution_providers: None,
86
- lowercase_input: false,
87
- max_input_chars: None,
88
84
  };
89
- let pool = SessionPool::new(session, &model_path, &model_config)?;
85
+ let pool = SessionPool::new(&model_path, &model_config, resolve_pool_size())?;
90
86
  Ok(Self { tokenizer, pool, config })
91
87
  }
92
88
 
@@ -102,13 +98,12 @@ impl Reranker {
102
98
 
103
99
  fn score_tokenized(&self, tokenized: &crate::tokenizer::Tokenized, apply_sigmoid: bool) -> Result<Vec<f32>> {
104
100
  let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
105
- let output_name = self.config.output_tensor.clone();
106
101
  let inputs = input_tensors.inputs;
107
102
 
108
103
  self.pool.with_session(|session| {
109
104
  let outputs = session.run(inputs).map_err(|e| GteError::Ort(e.to_string()))?;
110
105
 
111
- let array = extract_output_tensor(&outputs, output_name.as_str())?;
106
+ let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
112
107
 
113
108
  let mut scores = match array.ndim() {
114
109
  1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
@@ -116,14 +111,16 @@ impl Reranker {
116
111
  let shape = array.shape();
117
112
  if shape[1] == 0 {
118
113
  return Err(GteError::Inference(format!(
119
- "reranker output '{output_name}' has invalid shape {shape:?}"
114
+ "reranker output '{}' has invalid shape {shape:?}",
115
+ self.config.output_tensor
120
116
  )));
121
117
  }
122
118
  array.slice(ndarray::s![.., 0]).to_vec()
123
119
  }
124
120
  n => {
125
121
  return Err(GteError::Inference(format!(
126
- "reranker output '{output_name}' rank {n} is unsupported; expected rank 1 or 2"
122
+ "reranker output '{}' rank {n} is unsupported; expected rank 1 or 2",
123
+ self.config.output_tensor
127
124
  )))
128
125
  }
129
126
  };
@@ -3,7 +3,7 @@
3
3
  #![allow(unused_results)]
4
4
  #![allow(unused_qualifications)]
5
5
 
6
- use crate::embedder::{normalize_l2, output_name_suggests_normalized, Embedder};
6
+ use crate::embedder::Embedder;
7
7
  use crate::error::GteError;
8
8
  use crate::model_config::ModelLoadOverrides;
9
9
  use crate::reranker::Reranker;
@@ -15,7 +15,6 @@ use std::sync::Arc;
15
15
  #[wrap(class = "GTE::Embedder", free_immediately, size)]
16
16
  pub struct RbEmbedder {
17
17
  inner: Arc<Embedder>,
18
- normalize: bool,
19
18
  }
20
19
 
21
20
  #[wrap(class = "GTE::Reranker", free_immediately, size)]
@@ -38,7 +37,6 @@ pub struct RbTensor {
38
37
  struct InferArgs {
39
38
  embedder: *const Embedder,
40
39
  texts: *const Vec<String>,
41
- normalize: bool,
42
40
  result: Option<crate::error::Result<ndarray::Array2<f32>>>,
43
41
  }
44
42
 
@@ -66,15 +64,7 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
66
64
 
67
65
  unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
68
66
  let args = &mut *(ptr as *mut InferArgs);
69
- let run_result = catch_unwind(AssertUnwindSafe(|| {
70
- // Full embedding path (tokenization + inference) runs without the GVL.
71
- let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
72
- if args.normalize {
73
- Ok(normalize_l2(embeddings))
74
- } else {
75
- Ok(embeddings)
76
- }
77
- }));
67
+ let run_result = catch_unwind(AssertUnwindSafe(|| (*args.embedder).embed(&*args.texts)));
78
68
  args.result = Some(match run_result {
79
69
  Ok(result) => result,
80
70
  Err(payload) => {
@@ -97,14 +87,9 @@ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
97
87
  std::ptr::null_mut()
98
88
  }
99
89
 
100
- fn infer_without_gvl(
101
- embedder: &Arc<Embedder>,
102
- normalize: bool,
103
- texts: Vec<String>,
104
- ) -> Result<ndarray::Array2<f32>, Error> {
90
+ fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
105
91
  let embeddings = unsafe {
106
- let mut args =
107
- InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, normalize, result: None };
92
+ let mut args = InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, result: None };
108
93
  rb_sys::rb_thread_call_without_gvl(
109
94
  Some(run_embed_without_gvl),
110
95
  &mut args as *mut InferArgs as *mut c_void,
@@ -167,13 +152,10 @@ impl RbEmbedder {
167
152
  dir_path: String,
168
153
  optimization_level: u8,
169
154
  model_name: String,
170
- normalize: bool,
171
155
  output_tensor: String,
172
156
  max_length: usize,
173
157
  padding: String,
174
158
  execution_providers: String,
175
- lowercase_input: bool,
176
- max_input_chars: usize,
177
159
  ) -> Result<Self, Error> {
178
160
  let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
179
161
  let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
@@ -181,29 +163,26 @@ impl RbEmbedder {
181
163
  let execution_providers_override =
182
164
  if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
183
165
  let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
184
- let max_input_chars_override = if max_input_chars == 0 { None } else { Some(max_input_chars) };
185
166
  let overrides = ModelLoadOverrides {
186
167
  model_name: name,
187
168
  output_tensor: output_override,
188
169
  max_length: max_length_override,
189
170
  padding: padding_override,
190
171
  execution_providers: execution_providers_override,
191
- lowercase_input: Some(lowercase_input),
192
- max_input_chars: max_input_chars_override,
172
+ ..ModelLoadOverrides::default()
193
173
  };
194
174
  let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
195
- let skip_normalize = normalize && output_name_suggests_normalized(&embedder.config.output_tensor);
196
- Ok(RbEmbedder { inner: Arc::new(embedder), normalize: normalize && !skip_normalize })
175
+ Ok(RbEmbedder { inner: Arc::new(embedder) })
197
176
  }
198
177
 
199
178
  pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
200
179
  let texts: Vec<String> = texts.to_vec()?;
201
- let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
180
+ let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
202
181
  tensor_from_array(embeddings)
203
182
  }
204
183
 
205
184
  pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
206
- let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
185
+ let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
207
186
  tensor_from_array(embeddings)
208
187
  }
209
188
  }
@@ -219,8 +198,6 @@ impl RbReranker {
219
198
  max_length: usize,
220
199
  padding: String,
221
200
  execution_providers: String,
222
- _lowercase_input: bool,
223
- _max_input_chars: usize,
224
201
  ) -> Result<Self, Error> {
225
202
  let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
226
203
  let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
@@ -335,12 +312,12 @@ impl RbTensor {
335
312
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
336
313
  let module = ruby.define_module("GTE")?;
337
314
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
338
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 10))?;
315
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
339
316
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
340
317
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
341
318
 
342
319
  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
343
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 10))?;
320
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
344
321
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
345
322
 
346
323
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;
@@ -3,77 +3,58 @@ use crate::model_config::{ExtractorMode, ModelConfig};
3
3
  use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
- use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
6
+ use ndarray::{Array2, ArrayViewD, Ix2};
7
7
  use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
8
8
  use ort::session::{OutputSelector, RunOptions, Session};
9
- use std::cell::RefCell;
10
- use std::collections::hash_map::Entry;
11
- use std::collections::HashMap;
12
- use std::path::{Path, PathBuf};
9
+ use parking_lot::Mutex;
10
+ use std::path::Path;
13
11
  use std::sync::atomic::{AtomicUsize, Ordering};
14
12
 
15
- // ---------------------------------------------------------------------------
16
- // Thread-local session storage — each OS thread lazily creates its own ONNX
17
- // session the first time it calls into a given pool. No Mutex, no contention.
18
- // ---------------------------------------------------------------------------
19
-
20
- static NEXT_POOL_ID: AtomicUsize = AtomicUsize::new(1);
21
-
22
- struct SessionRecipe {
23
- model_path: PathBuf,
24
- build_config: ModelConfig,
25
- }
26
-
27
- thread_local! {
28
- static SESSIONS: RefCell<HashMap<usize, Session>> = RefCell::new(HashMap::new());
13
+ pub(crate) fn resolve_pool_size() -> usize {
14
+ if let Some(n) =
15
+ std::env::var("GTE_SESSION_POOL_SIZE").ok().and_then(|v| v.trim().parse::<usize>().ok()).filter(|&n| n > 0)
16
+ {
17
+ return n;
18
+ }
19
+ let cpus = std::thread::available_parallelism().map(std::num::NonZero::get).unwrap_or(2);
20
+ cpus.min(4).max(1)
29
21
  }
30
22
 
31
23
  pub struct SessionPool {
32
- pool_id: usize,
33
- recipe: SessionRecipe,
24
+ sessions: Vec<Mutex<Session>>,
25
+ next_idx: AtomicUsize,
34
26
  }
35
27
 
36
28
  impl SessionPool {
37
- pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
38
- let pool_id = NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed);
39
-
40
- SESSIONS.with(|map| {
41
- _ = map.borrow_mut().insert(pool_id, initial);
42
- });
43
-
44
- Ok(Self {
45
- pool_id,
46
- recipe: SessionRecipe { model_path: model_path.to_path_buf(), build_config: build_config.clone() },
47
- })
48
- }
49
-
50
- pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
51
- self.with_session(|session| run_session(session, tokenized, config))
29
+ pub fn new(model_path: &Path, config: &ModelConfig, pool_size: usize) -> Result<Self> {
30
+ let sessions = (0..pool_size)
31
+ .map(|_| build_session(model_path, config))
32
+ .collect::<Result<Vec<_>>>()?
33
+ .into_iter()
34
+ .map(Mutex::new)
35
+ .collect();
36
+ Ok(Self { sessions, next_idx: AtomicUsize::new(0) })
52
37
  }
53
38
 
54
39
  pub fn with_session<F, R>(&self, f: F) -> Result<R>
55
40
  where
56
41
  F: FnOnce(&mut Session) -> Result<R>,
57
42
  {
58
- SESSIONS.with(|map| {
59
- let mut map = map.borrow_mut();
60
- let session = match map.entry(self.pool_id) {
61
- Entry::Occupied(e) => e.into_mut(),
62
- Entry::Vacant(e) => {
63
- let session = build_session(&self.recipe.model_path, &self.recipe.build_config)?;
64
- e.insert(session)
65
- }
66
- };
67
- f(session)
68
- })
43
+ let idx = if self.sessions.len() == 1 {
44
+ 0
45
+ } else {
46
+ self.next_idx.fetch_add(1, Ordering::Relaxed) % self.sessions.len()
47
+ };
48
+ let mut session = self.sessions[idx].lock();
49
+ f(&mut session)
69
50
  }
70
- }
71
51
 
72
- // ---------------------------------------------------------------------------
73
- // Session construction
74
- // ---------------------------------------------------------------------------
52
+ pub fn len(&self) -> usize {
53
+ self.sessions.len()
54
+ }
55
+ }
75
56
 
76
- pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
57
+ pub(crate) fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
77
58
  fn ort_err(e: impl std::fmt::Display) -> GteError {
78
59
  GteError::Ort(e.to_string())
79
60
  }
@@ -109,10 +90,14 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
109
90
  }
110
91
 
111
92
  fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
112
- let mut providers = Vec::new();
113
93
  #[cfg(target_arch = "aarch64")]
114
- providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
115
- providers
94
+ {
95
+ vec![XNNPACKExecutionProvider::default().build().fail_silently()]
96
+ }
97
+ #[cfg(not(target_arch = "aarch64"))]
98
+ {
99
+ Vec::new()
100
+ }
116
101
  }
117
102
 
118
103
  fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
@@ -138,11 +123,7 @@ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionP
138
123
  providers
139
124
  }
140
125
 
141
- // ---------------------------------------------------------------------------
142
- // Run a single inference
143
- // ---------------------------------------------------------------------------
144
-
145
- pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
126
+ pub(crate) fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
146
127
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
147
128
  let run_opts = RunOptions::new()
148
129
  .map_err(|e| GteError::Ort(e.to_string()))?
@@ -156,7 +137,7 @@ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelC
156
137
 
157
138
  fn extract_embeddings(
158
139
  array: ArrayViewD<'_, f32>,
159
- attention_mask: ArrayView2<'_, i64>,
140
+ attention_mask: ndarray::ArrayView2<'_, i64>,
160
141
  config: &ModelConfig,
161
142
  ) -> Result<Array2<f32>> {
162
143
  match config.mode {
@@ -189,21 +170,6 @@ mod tests {
189
170
 
190
171
  use super::extract_embeddings;
191
172
 
192
- fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
193
- order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
194
- }
195
-
196
- fn parse_provider_registrations(order: &str) -> Vec<&str> {
197
- let mut providers = Vec::new();
198
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
199
- match provider {
200
- "xnnpack" | "coreml" => providers.push(provider),
201
- _ => {}
202
- }
203
- }
204
- providers
205
- }
206
-
207
173
  fn test_config(mode: ExtractorMode) -> ModelConfig {
208
174
  ModelConfig {
209
175
  max_length: 8,
@@ -214,8 +180,6 @@ mod tests {
214
180
  with_attention_mask: true,
215
181
  optimization_level: 3,
216
182
  execution_providers: None,
217
- lowercase_input: false,
218
- max_input_chars: None,
219
183
  }
220
184
  }
221
185
 
@@ -224,37 +188,6 @@ mod tests {
224
188
  ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
225
189
  }
226
190
 
227
- #[test]
228
- fn parse_provider_registrations_keeps_supported_order() {
229
- let parsed = parse_provider_registrations("xnnpack,coreml");
230
- assert_eq!(parsed, vec!["xnnpack", "coreml"]);
231
- }
232
-
233
- #[test]
234
- fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
235
- assert!(parse_provider_registrations("cpu").is_empty());
236
- assert!(parse_provider_registrations("none").is_empty());
237
- assert!(parse_provider_registrations("none,cpu").is_empty());
238
- }
239
-
240
- #[test]
241
- fn parse_provider_registrations_ignores_unknowns_and_empties() {
242
- let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
243
- assert_eq!(parsed, vec!["xnnpack", "coreml"]);
244
- }
245
-
246
- #[test]
247
- fn resolve_provider_order_prefers_override() {
248
- assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
249
- assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
250
- }
251
-
252
- #[test]
253
- fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
254
- assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
255
- assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
256
- }
257
-
258
191
  #[test]
259
192
  fn extract_embeddings_raw_copies_only_final_matrix() {
260
193
  let output = array![[1.0f32, 2.0], [3.0, 4.0]];
@@ -287,4 +220,20 @@ mod tests {
287
220
 
288
221
  assert_eq!(extracted, expected);
289
222
  }
223
+
224
+ #[test]
225
+ fn resolve_pool_size_uses_env_var() {
226
+ std::env::set_var("GTE_SESSION_POOL_SIZE", "16");
227
+ let size = super::resolve_pool_size();
228
+ assert_eq!(size, 16);
229
+ std::env::remove_var("GTE_SESSION_POOL_SIZE");
230
+ }
231
+
232
+ #[test]
233
+ fn resolve_pool_size_defaults_to_cpu_count_capped_at_4() {
234
+ // Without GTE_SESSION_POOL_SIZE, the default is min(available_parallelism, 4).max(1).
235
+ // On any machine with >= 1 CPU, this should return between 1 and 4.
236
+ let size = super::resolve_pool_size();
237
+ assert!((1..=4).contains(&size), "expected 1-4, got {size}");
238
+ }
290
239
  }
@@ -1,14 +1,13 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::PaddingMode;
3
+ use ndarray::Array2;
3
4
  use std::path::Path;
4
5
  use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
5
6
 
6
7
  pub struct Tokenized {
7
- pub rows: usize,
8
- pub cols: usize,
9
- pub input_ids: Vec<i64>,
10
- pub attn_masks: Vec<i64>,
11
- pub type_ids: Option<Vec<i64>>,
8
+ pub input_ids: Array2<i64>,
9
+ pub attn_masks: Array2<i64>,
10
+ pub type_ids: Option<Array2<i64>>,
12
11
  }
13
12
 
14
13
  pub struct Tokenizer {
@@ -24,7 +23,6 @@ impl Tokenizer {
24
23
  padding_mode: PaddingMode,
25
24
  fixed_padding_length: Option<usize>,
26
25
  ) -> Result<Self> {
27
- #[allow(unused_results)]
28
26
  {
29
27
  let mut tokenizer =
30
28
  tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| GteError::Tokenizer(e.to_string()))?;
@@ -34,41 +32,59 @@ impl Tokenizer {
34
32
  strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
35
33
  ..Default::default()
36
34
  };
37
- tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
38
- tokenizer.with_padding(Some(padding));
35
+ let _ = tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
36
+ let _ = tokenizer.with_padding(Some(padding));
39
37
 
40
38
  Ok(Self { tokenizer, with_type_ids })
41
39
  }
42
40
  }
43
41
 
44
42
  pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
45
- if texts.len() == 1 {
46
- let encoding =
47
- self.tokenizer.encode_fast(texts[0].as_str(), true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
48
- return Ok(build_tokenized_single(&encoding, self.with_type_ids));
43
+ if texts.is_empty() {
44
+ return Ok(Tokenized {
45
+ input_ids: Array2::zeros((0, 0)),
46
+ attn_masks: Array2::zeros((0, 0)),
47
+ type_ids: None,
48
+ });
49
49
  }
50
50
 
51
51
  let encode_inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
52
52
  let encodings =
53
53
  self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
54
54
 
55
- Ok(build_tokenized(&encodings, self.with_type_ids))
55
+ build_tokenized(&encodings, self.with_type_ids)
56
56
  }
57
57
 
58
58
  pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
59
+ if pairs.is_empty() {
60
+ return Ok(Tokenized {
61
+ input_ids: Array2::zeros((0, 0)),
62
+ attn_masks: Array2::zeros((0, 0)),
63
+ type_ids: None,
64
+ });
65
+ }
66
+
59
67
  let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
60
68
  pairs.iter().map(|(left, right)| (left.as_str(), right.as_str()).into()).collect();
61
69
  let encodings =
62
70
  self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
63
- Ok(build_tokenized(&encodings, self.with_type_ids))
71
+ build_tokenized(&encodings, self.with_type_ids)
64
72
  }
65
73
 
66
74
  pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
75
+ if candidates.is_empty() {
76
+ return Ok(Tokenized {
77
+ input_ids: Array2::zeros((0, 0)),
78
+ attn_masks: Array2::zeros((0, 0)),
79
+ type_ids: None,
80
+ });
81
+ }
82
+
67
83
  let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
68
84
  candidates.iter().map(|candidate| (query, candidate.as_str()).into()).collect();
69
85
  let encodings =
70
86
  self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
71
- Ok(build_tokenized(&encodings, self.with_type_ids))
87
+ build_tokenized(&encodings, self.with_type_ids)
72
88
  }
73
89
  }
74
90
 
@@ -102,36 +118,30 @@ fn resolve_padding_strategy(
102
118
  }
103
119
  }
104
120
 
105
- fn build_tokenized_single(encoding: &tokenizers::Encoding, with_type_ids: bool) -> Tokenized {
106
- let cols = encoding.len();
107
-
108
- let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&v| i64::from(v)).collect();
109
- let attn_masks: Vec<i64> = encoding.get_attention_mask().iter().map(|&v| i64::from(v)).collect();
110
- let type_ids: Option<Vec<i64>> =
111
- with_type_ids.then(|| encoding.get_type_ids().iter().map(|&v| i64::from(v)).collect());
112
-
113
- Tokenized { rows: 1, cols, input_ids, attn_masks, type_ids }
121
+ fn to_i64(array: &[u32]) -> Vec<i64> {
122
+ array.iter().map(|&v| v as i64).collect()
114
123
  }
115
124
 
116
- fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Tokenized {
125
+ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Result<Tokenized> {
117
126
  let rows = encodings.len();
118
127
  let cols = encodings.first().map_or(0, tokenizers::Encoding::len);
119
- let len = rows * cols;
128
+ if rows == 0 || cols == 0 {
129
+ return Ok(Tokenized { input_ids: Array2::zeros((0, 0)), attn_masks: Array2::zeros((0, 0)), type_ids: None });
130
+ }
120
131
 
121
- let mut input_ids = Vec::with_capacity(len);
122
- let mut attn_masks = Vec::with_capacity(len);
123
- let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
132
+ let mut input_ids = Array2::zeros((0, cols));
133
+ let mut attn_masks = Array2::zeros((0, cols));
134
+ let mut type_ids = with_type_ids.then(|| Array2::zeros((0, cols)));
124
135
 
125
136
  for encoding in encodings {
126
- input_ids.extend(encoding.get_ids().iter().map(|&v| i64::from(v)));
127
- attn_masks.extend(encoding.get_attention_mask().iter().map(|&v| i64::from(v)));
128
-
129
- if let Some(type_ids) = type_ids.as_mut() {
130
- type_ids.extend(encoding.get_type_ids().iter().map(|&v| i64::from(v)));
137
+ input_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_ids())))?;
138
+ attn_masks.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_attention_mask())))?;
139
+ if let Some(ref mut type_ids) = type_ids {
140
+ type_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_type_ids())))?;
131
141
  }
132
142
  }
133
143
 
134
- Tokenized { rows, cols, input_ids, attn_masks, type_ids }
144
+ Ok(Tokenized { input_ids, attn_masks, type_ids })
135
145
  }
136
146
 
137
147
  #[cfg(test)]
@@ -154,9 +164,6 @@ mod tests {
154
164
 
155
165
  #[test]
156
166
  fn resolve_padding_strategy_auto_always_uses_batch_longest() {
157
- // Auto ignores fixed_padding_length from tokenizer.json — BatchLongest is
158
- // always faster for inference and correct for variable-length inputs.
159
- // Use PaddingMode::Fixed explicitly when fixed-length padding is required.
160
167
  assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)), PaddingStrategy::BatchLongest));
161
168
  assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 512, None), PaddingStrategy::BatchLongest));
162
169
  }
@@ -1,4 +1,4 @@
1
- use gte::embedder::normalize_l2;
1
+ use gte::postprocess::normalize_l2;
2
2
  use ndarray::array;
3
3
 
4
4
  #[test]