gte 0.0.15-aarch64-linux → 0.0.16-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ use crate::model_profile::{
6
6
  };
7
7
  use crate::pipeline::{extract_output_tensor, InputTensors};
8
8
  use crate::postprocess::sigmoid_scores;
9
- use crate::session::{build_session, SessionPool};
9
+ use crate::session::{build_session, resolve_pool_size, SessionPool};
10
10
  use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
11
11
  use std::path::{Path, PathBuf};
12
12
 
@@ -54,8 +54,6 @@ impl Reranker {
54
54
  with_attention_mask: true,
55
55
  optimization_level,
56
56
  execution_providers: overrides.execution_providers.map(str::to_string),
57
- lowercase_input: false,
58
- max_input_chars: None,
59
57
  };
60
58
  let session = build_session(&model_path, &probe_config)?;
61
59
 
@@ -83,10 +81,8 @@ impl Reranker {
83
81
  with_attention_mask: config.with_attention_mask,
84
82
  optimization_level,
85
83
  execution_providers: None,
86
- lowercase_input: false,
87
- max_input_chars: None,
88
84
  };
89
- let pool = SessionPool::new(session, &model_path, &model_config)?;
85
+ let pool = SessionPool::new(&model_path, &model_config, resolve_pool_size())?;
90
86
  Ok(Self { tokenizer, pool, config })
91
87
  }
92
88
 
@@ -102,13 +98,12 @@ impl Reranker {
102
98
 
103
99
  fn score_tokenized(&self, tokenized: &crate::tokenizer::Tokenized, apply_sigmoid: bool) -> Result<Vec<f32>> {
104
100
  let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
105
- let output_name = self.config.output_tensor.clone();
106
101
  let inputs = input_tensors.inputs;
107
102
 
108
103
  self.pool.with_session(|session| {
109
104
  let outputs = session.run(inputs).map_err(|e| GteError::Ort(e.to_string()))?;
110
105
 
111
- let array = extract_output_tensor(&outputs, output_name.as_str())?;
106
+ let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
112
107
 
113
108
  let mut scores = match array.ndim() {
114
109
  1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
@@ -116,14 +111,16 @@ impl Reranker {
116
111
  let shape = array.shape();
117
112
  if shape[1] == 0 {
118
113
  return Err(GteError::Inference(format!(
119
- "reranker output '{output_name}' has invalid shape {shape:?}"
114
+ "reranker output '{}' has invalid shape {shape:?}",
115
+ self.config.output_tensor
120
116
  )));
121
117
  }
122
118
  array.slice(ndarray::s![.., 0]).to_vec()
123
119
  }
124
120
  n => {
125
121
  return Err(GteError::Inference(format!(
126
- "reranker output '{output_name}' rank {n} is unsupported; expected rank 1 or 2"
122
+ "reranker output '{}' rank {n} is unsupported; expected rank 1 or 2",
123
+ self.config.output_tensor
127
124
  )))
128
125
  }
129
126
  };
@@ -3,7 +3,7 @@
3
3
  #![allow(unused_results)]
4
4
  #![allow(unused_qualifications)]
5
5
 
6
- use crate::embedder::{normalize_l2, output_name_suggests_normalized, Embedder};
6
+ use crate::embedder::Embedder;
7
7
  use crate::error::GteError;
8
8
  use crate::model_config::ModelLoadOverrides;
9
9
  use crate::reranker::Reranker;
@@ -15,7 +15,6 @@ use std::sync::Arc;
15
15
  #[wrap(class = "GTE::Embedder", free_immediately, size)]
16
16
  pub struct RbEmbedder {
17
17
  inner: Arc<Embedder>,
18
- normalize: bool,
19
18
  }
20
19
 
21
20
  #[wrap(class = "GTE::Reranker", free_immediately, size)]
@@ -38,7 +37,6 @@ pub struct RbTensor {
38
37
  struct InferArgs {
39
38
  embedder: *const Embedder,
40
39
  texts: *const Vec<String>,
41
- normalize: bool,
42
40
  result: Option<crate::error::Result<ndarray::Array2<f32>>>,
43
41
  }
44
42
 
@@ -66,15 +64,7 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
66
64
 
67
65
  unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
68
66
  let args = &mut *(ptr as *mut InferArgs);
69
- let run_result = catch_unwind(AssertUnwindSafe(|| {
70
- // Full embedding path (tokenization + inference) runs without the GVL.
71
- let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
72
- if args.normalize {
73
- Ok(normalize_l2(embeddings))
74
- } else {
75
- Ok(embeddings)
76
- }
77
- }));
67
+ let run_result = catch_unwind(AssertUnwindSafe(|| (*args.embedder).embed(&*args.texts)));
78
68
  args.result = Some(match run_result {
79
69
  Ok(result) => result,
80
70
  Err(payload) => {
@@ -97,14 +87,9 @@ unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
97
87
  std::ptr::null_mut()
98
88
  }
99
89
 
100
- fn infer_without_gvl(
101
- embedder: &Arc<Embedder>,
102
- normalize: bool,
103
- texts: Vec<String>,
104
- ) -> Result<ndarray::Array2<f32>, Error> {
90
+ fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
105
91
  let embeddings = unsafe {
106
- let mut args =
107
- InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, normalize, result: None };
92
+ let mut args = InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, result: None };
108
93
  rb_sys::rb_thread_call_without_gvl(
109
94
  Some(run_embed_without_gvl),
110
95
  &mut args as *mut InferArgs as *mut c_void,
@@ -167,13 +152,10 @@ impl RbEmbedder {
167
152
  dir_path: String,
168
153
  optimization_level: u8,
169
154
  model_name: String,
170
- normalize: bool,
171
155
  output_tensor: String,
172
156
  max_length: usize,
173
157
  padding: String,
174
158
  execution_providers: String,
175
- lowercase_input: bool,
176
- max_input_chars: usize,
177
159
  ) -> Result<Self, Error> {
178
160
  let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
179
161
  let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
@@ -181,29 +163,26 @@ impl RbEmbedder {
181
163
  let execution_providers_override =
182
164
  if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
183
165
  let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
184
- let max_input_chars_override = if max_input_chars == 0 { None } else { Some(max_input_chars) };
185
166
  let overrides = ModelLoadOverrides {
186
167
  model_name: name,
187
168
  output_tensor: output_override,
188
169
  max_length: max_length_override,
189
170
  padding: padding_override,
190
171
  execution_providers: execution_providers_override,
191
- lowercase_input: Some(lowercase_input),
192
- max_input_chars: max_input_chars_override,
172
+ ..ModelLoadOverrides::default()
193
173
  };
194
174
  let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
195
- let skip_normalize = normalize && output_name_suggests_normalized(&embedder.config.output_tensor);
196
- Ok(RbEmbedder { inner: Arc::new(embedder), normalize: normalize && !skip_normalize })
175
+ Ok(RbEmbedder { inner: Arc::new(embedder) })
197
176
  }
198
177
 
199
178
  pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
200
179
  let texts: Vec<String> = texts.to_vec()?;
201
- let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
180
+ let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
202
181
  tensor_from_array(embeddings)
203
182
  }
204
183
 
205
184
  pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
206
- let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
185
+ let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
207
186
  tensor_from_array(embeddings)
208
187
  }
209
188
  }
@@ -219,8 +198,6 @@ impl RbReranker {
219
198
  max_length: usize,
220
199
  padding: String,
221
200
  execution_providers: String,
222
- _lowercase_input: bool,
223
- _max_input_chars: usize,
224
201
  ) -> Result<Self, Error> {
225
202
  let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
226
203
  let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
@@ -335,12 +312,12 @@ impl RbTensor {
335
312
  pub fn register(ruby: &Ruby) -> Result<(), Error> {
336
313
  let module = ruby.define_module("GTE")?;
337
314
  let embedder_class = module.define_class("Embedder", ruby.class_object())?;
338
- embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 10))?;
315
+ embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
339
316
  embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
340
317
  embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
341
318
 
342
319
  let reranker_class = module.define_class("Reranker", ruby.class_object())?;
343
- reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 10))?;
320
+ reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
344
321
  reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
345
322
 
346
323
  let tensor_class = module.define_class("Tensor", ruby.class_object())?;
@@ -3,132 +3,58 @@ use crate::model_config::{ExtractorMode, ModelConfig};
3
3
  use crate::pipeline::{extract_output_tensor, InputTensors};
4
4
  use crate::postprocess::mean_pool;
5
5
  use crate::tokenizer::Tokenized;
6
- use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
6
+ use ndarray::{Array2, ArrayViewD, Ix2};
7
7
  use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
8
8
  use ort::session::{OutputSelector, RunOptions, Session};
9
9
  use parking_lot::Mutex;
10
- use std::path::{Path, PathBuf};
10
+ use std::path::Path;
11
11
  use std::sync::atomic::{AtomicUsize, Ordering};
12
- use std::sync::Arc;
13
12
 
14
- // ---------------------------------------------------------------------------
15
- // Lazy session pool — starts with 1 session, grows on contention, capped.
16
- //
17
- // Pool max is resolved in order:
18
- // 1. GTE_SESSION_POOL_SIZE env var (explicit override)
19
- // 2. Auto: 2 (conservative: 2× pure Ruby memory at peak, no OOM risk)
20
- //
21
- // At idle the pool holds 1 session (same memory as pure Ruby's single
22
- // OnnxRuntime::Model). When all existing sessions are busy and the cap
23
- // hasn't been reached, a new session is created on-demand.
24
- // ---------------------------------------------------------------------------
25
-
26
- fn resolve_pool_cap() -> usize {
13
+ pub(crate) fn resolve_pool_size() -> usize {
27
14
  if let Some(n) =
28
15
  std::env::var("GTE_SESSION_POOL_SIZE").ok().and_then(|v| v.trim().parse::<usize>().ok()).filter(|&n| n > 0)
29
16
  {
30
17
  return n;
31
18
  }
32
- 2
19
+ let cpus = std::thread::available_parallelism().map(std::num::NonZero::get).unwrap_or(2);
20
+ cpus.min(4).max(1)
33
21
  }
34
22
 
35
23
  pub struct SessionPool {
36
- inner: Mutex<PoolInner>,
24
+ sessions: Vec<Mutex<Session>>,
37
25
  next_idx: AtomicUsize,
38
- cap: usize,
39
- }
40
-
41
- struct PoolInner {
42
- sessions: Vec<Arc<Mutex<Session>>>,
43
- model_path: PathBuf,
44
- build_config: ModelConfig,
45
26
  }
46
27
 
47
28
  impl SessionPool {
48
- pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
49
- let cap = resolve_pool_cap();
50
- let sessions = vec![Arc::new(Mutex::new(initial))];
51
-
52
- Ok(Self {
53
- inner: Mutex::new(PoolInner {
54
- sessions,
55
- model_path: model_path.to_path_buf(),
56
- build_config: build_config.clone(),
57
- }),
58
- next_idx: AtomicUsize::new(0),
59
- cap,
60
- })
61
- }
62
-
63
- pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
64
- self.with_session(|session| run_session(session, tokenized, config))
29
+ pub fn new(model_path: &Path, config: &ModelConfig, pool_size: usize) -> Result<Self> {
30
+ let sessions = (0..pool_size)
31
+ .map(|_| build_session(model_path, config))
32
+ .collect::<Result<Vec<_>>>()?
33
+ .into_iter()
34
+ .map(Mutex::new)
35
+ .collect();
36
+ Ok(Self { sessions, next_idx: AtomicUsize::new(0) })
65
37
  }
66
38
 
67
39
  pub fn with_session<F, R>(&self, f: F) -> Result<R>
68
40
  where
69
41
  F: FnOnce(&mut Session) -> Result<R>,
70
42
  {
71
- const SPIN_LIMIT: u32 = 64;
72
-
73
- loop {
74
- // Snapshot the pool under the outer lock so the scan below
75
- // doesn't contend on that lock at all.
76
- let arcs: Vec<Arc<Mutex<Session>>> = {
77
- let inner = self.inner.lock();
78
- inner.sessions.clone()
79
- };
80
- let len = arcs.len();
81
- let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
82
-
83
- for offset in 0..len {
84
- let idx = (start + offset) % len;
85
- if let Some(mut guard) = arcs[idx].try_lock() {
86
- return f(&mut guard);
87
- }
88
- }
89
-
90
- // All sessions busy — try to grow the pool
91
- let grew = {
92
- let mut inner = self.inner.lock();
93
- if inner.sessions.len() < self.cap {
94
- match build_session(&inner.model_path, &inner.build_config) {
95
- Ok(session) => {
96
- inner.sessions.push(Arc::new(Mutex::new(session)));
97
- true
98
- }
99
- Err(e) => return Err(e),
100
- }
101
- } else {
102
- false
103
- }
104
- };
105
-
106
- if grew {
107
- continue;
108
- }
109
-
110
- // At cap — spin briefly, then block on a session
111
- let idx = self.next_idx.fetch_add(1, Ordering::Relaxed) % len;
112
- let arc = Arc::clone(&arcs[idx]);
113
-
114
- for _ in 0..SPIN_LIMIT {
115
- if let Some(mut guard) = arc.try_lock() {
116
- return f(&mut guard);
117
- }
118
- std::hint::spin_loop();
119
- }
43
+ let idx = if self.sessions.len() == 1 {
44
+ 0
45
+ } else {
46
+ self.next_idx.fetch_add(1, Ordering::Relaxed) % self.sessions.len()
47
+ };
48
+ let mut session = self.sessions[idx].lock();
49
+ f(&mut session)
50
+ }
120
51
 
121
- let mut guard = arc.lock();
122
- return f(&mut guard);
123
- }
52
+ pub fn len(&self) -> usize {
53
+ self.sessions.len()
124
54
  }
125
55
  }
126
56
 
127
- // ---------------------------------------------------------------------------
128
- // Session construction
129
- // ---------------------------------------------------------------------------
130
-
131
- pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
57
+ pub(crate) fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
132
58
  fn ort_err(e: impl std::fmt::Display) -> GteError {
133
59
  GteError::Ort(e.to_string())
134
60
  }
@@ -164,10 +90,14 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
164
90
  }
165
91
 
166
92
  fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
167
- let mut providers = Vec::new();
168
93
  #[cfg(target_arch = "aarch64")]
169
- providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
170
- providers
94
+ {
95
+ vec![XNNPACKExecutionProvider::default().build().fail_silently()]
96
+ }
97
+ #[cfg(not(target_arch = "aarch64"))]
98
+ {
99
+ Vec::new()
100
+ }
171
101
  }
172
102
 
173
103
  fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
@@ -193,11 +123,7 @@ fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionP
193
123
  providers
194
124
  }
195
125
 
196
- // ---------------------------------------------------------------------------
197
- // Run a single inference
198
- // ---------------------------------------------------------------------------
199
-
200
- pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
126
+ pub(crate) fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
201
127
  let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
202
128
  let run_opts = RunOptions::new()
203
129
  .map_err(|e| GteError::Ort(e.to_string()))?
@@ -211,7 +137,7 @@ pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelC
211
137
 
212
138
  fn extract_embeddings(
213
139
  array: ArrayViewD<'_, f32>,
214
- attention_mask: ArrayView2<'_, i64>,
140
+ attention_mask: ndarray::ArrayView2<'_, i64>,
215
141
  config: &ModelConfig,
216
142
  ) -> Result<Array2<f32>> {
217
143
  match config.mode {
@@ -244,21 +170,6 @@ mod tests {
244
170
 
245
171
  use super::extract_embeddings;
246
172
 
247
- fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
248
- order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
249
- }
250
-
251
- fn parse_provider_registrations(order: &str) -> Vec<&str> {
252
- let mut providers = Vec::new();
253
- for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
254
- match provider {
255
- "xnnpack" | "coreml" => providers.push(provider),
256
- _ => {}
257
- }
258
- }
259
- providers
260
- }
261
-
262
173
  fn test_config(mode: ExtractorMode) -> ModelConfig {
263
174
  ModelConfig {
264
175
  max_length: 8,
@@ -269,8 +180,6 @@ mod tests {
269
180
  with_attention_mask: true,
270
181
  optimization_level: 3,
271
182
  execution_providers: None,
272
- lowercase_input: false,
273
- max_input_chars: None,
274
183
  }
275
184
  }
276
185
 
@@ -279,37 +188,6 @@ mod tests {
279
188
  ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
280
189
  }
281
190
 
282
- #[test]
283
- fn parse_provider_registrations_keeps_supported_order() {
284
- let parsed = parse_provider_registrations("xnnpack,coreml");
285
- assert_eq!(parsed, vec!["xnnpack", "coreml"]);
286
- }
287
-
288
- #[test]
289
- fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
290
- assert!(parse_provider_registrations("cpu").is_empty());
291
- assert!(parse_provider_registrations("none").is_empty());
292
- assert!(parse_provider_registrations("none,cpu").is_empty());
293
- }
294
-
295
- #[test]
296
- fn parse_provider_registrations_ignores_unknowns_and_empties() {
297
- let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
298
- assert_eq!(parsed, vec!["xnnpack", "coreml"]);
299
- }
300
-
301
- #[test]
302
- fn resolve_provider_order_prefers_override() {
303
- assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
304
- assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
305
- }
306
-
307
- #[test]
308
- fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
309
- assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
310
- assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
311
- }
312
-
313
191
  #[test]
314
192
  fn extract_embeddings_raw_copies_only_final_matrix() {
315
193
  let output = array![[1.0f32, 2.0], [3.0, 4.0]];
@@ -342,4 +220,20 @@ mod tests {
342
220
 
343
221
  assert_eq!(extracted, expected);
344
222
  }
223
+
224
+ #[test]
225
+ fn resolve_pool_size_uses_env_var() {
226
+ std::env::set_var("GTE_SESSION_POOL_SIZE", "16");
227
+ let size = super::resolve_pool_size();
228
+ assert_eq!(size, 16);
229
+ std::env::remove_var("GTE_SESSION_POOL_SIZE");
230
+ }
231
+
232
+ #[test]
233
+ fn resolve_pool_size_defaults_to_cpu_count_capped_at_4() {
234
+ // Without GTE_SESSION_POOL_SIZE, the default is min(available_parallelism, 4).max(1).
235
+ // On any machine with >= 1 CPU, this should return between 1 and 4.
236
+ let size = super::resolve_pool_size();
237
+ assert!((1..=4).contains(&size), "expected 1-4, got {size}");
238
+ }
345
239
  }
@@ -1,14 +1,13 @@
1
1
  use crate::error::{GteError, Result};
2
2
  use crate::model_config::PaddingMode;
3
+ use ndarray::Array2;
3
4
  use std::path::Path;
4
5
  use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
5
6
 
6
7
  pub struct Tokenized {
7
- pub rows: usize,
8
- pub cols: usize,
9
- pub input_ids: Vec<i64>,
10
- pub attn_masks: Vec<i64>,
11
- pub type_ids: Option<Vec<i64>>,
8
+ pub input_ids: Array2<i64>,
9
+ pub attn_masks: Array2<i64>,
10
+ pub type_ids: Option<Array2<i64>>,
12
11
  }
13
12
 
14
13
  pub struct Tokenizer {
@@ -24,7 +23,6 @@ impl Tokenizer {
24
23
  padding_mode: PaddingMode,
25
24
  fixed_padding_length: Option<usize>,
26
25
  ) -> Result<Self> {
27
- #[allow(unused_results)]
28
26
  {
29
27
  let mut tokenizer =
30
28
  tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| GteError::Tokenizer(e.to_string()))?;
@@ -34,41 +32,59 @@ impl Tokenizer {
34
32
  strategy: resolve_padding_strategy(padding_mode, max_length, fixed_padding_length),
35
33
  ..Default::default()
36
34
  };
37
- tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
38
- tokenizer.with_padding(Some(padding));
35
+ let _ = tokenizer.with_truncation(Some(truncation)).map_err(|e| GteError::Tokenizer(e.to_string()))?;
36
+ let _ = tokenizer.with_padding(Some(padding));
39
37
 
40
38
  Ok(Self { tokenizer, with_type_ids })
41
39
  }
42
40
  }
43
41
 
44
42
  pub fn tokenize(&self, texts: &[String]) -> Result<Tokenized> {
45
- if texts.len() == 1 {
46
- let encoding =
47
- self.tokenizer.encode_fast(texts[0].as_str(), true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
48
- return Ok(build_tokenized_single(&encoding, self.with_type_ids));
43
+ if texts.is_empty() {
44
+ return Ok(Tokenized {
45
+ input_ids: Array2::zeros((0, 0)),
46
+ attn_masks: Array2::zeros((0, 0)),
47
+ type_ids: None,
48
+ });
49
49
  }
50
50
 
51
51
  let encode_inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
52
52
  let encodings =
53
53
  self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
54
54
 
55
- Ok(build_tokenized(&encodings, self.with_type_ids))
55
+ build_tokenized(&encodings, self.with_type_ids)
56
56
  }
57
57
 
58
58
  pub fn tokenize_pairs(&self, pairs: &[(String, String)]) -> Result<Tokenized> {
59
+ if pairs.is_empty() {
60
+ return Ok(Tokenized {
61
+ input_ids: Array2::zeros((0, 0)),
62
+ attn_masks: Array2::zeros((0, 0)),
63
+ type_ids: None,
64
+ });
65
+ }
66
+
59
67
  let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
60
68
  pairs.iter().map(|(left, right)| (left.as_str(), right.as_str()).into()).collect();
61
69
  let encodings =
62
70
  self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
63
- Ok(build_tokenized(&encodings, self.with_type_ids))
71
+ build_tokenized(&encodings, self.with_type_ids)
64
72
  }
65
73
 
66
74
  pub fn tokenize_query_candidates(&self, query: &str, candidates: &[String]) -> Result<Tokenized> {
75
+ if candidates.is_empty() {
76
+ return Ok(Tokenized {
77
+ input_ids: Array2::zeros((0, 0)),
78
+ attn_masks: Array2::zeros((0, 0)),
79
+ type_ids: None,
80
+ });
81
+ }
82
+
67
83
  let encode_inputs: Vec<tokenizers::EncodeInput<'_>> =
68
84
  candidates.iter().map(|candidate| (query, candidate.as_str()).into()).collect();
69
85
  let encodings =
70
86
  self.tokenizer.encode_batch_fast(encode_inputs, true).map_err(|e| GteError::Tokenizer(e.to_string()))?;
71
- Ok(build_tokenized(&encodings, self.with_type_ids))
87
+ build_tokenized(&encodings, self.with_type_ids)
72
88
  }
73
89
  }
74
90
 
@@ -102,36 +118,30 @@ fn resolve_padding_strategy(
102
118
  }
103
119
  }
104
120
 
105
- fn build_tokenized_single(encoding: &tokenizers::Encoding, with_type_ids: bool) -> Tokenized {
106
- let cols = encoding.len();
107
-
108
- let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&v| i64::from(v)).collect();
109
- let attn_masks: Vec<i64> = encoding.get_attention_mask().iter().map(|&v| i64::from(v)).collect();
110
- let type_ids: Option<Vec<i64>> =
111
- with_type_ids.then(|| encoding.get_type_ids().iter().map(|&v| i64::from(v)).collect());
112
-
113
- Tokenized { rows: 1, cols, input_ids, attn_masks, type_ids }
121
+ fn to_i64(array: &[u32]) -> Vec<i64> {
122
+ array.iter().map(|&v| v as i64).collect()
114
123
  }
115
124
 
116
- fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Tokenized {
125
+ fn build_tokenized(encodings: &[tokenizers::Encoding], with_type_ids: bool) -> Result<Tokenized> {
117
126
  let rows = encodings.len();
118
127
  let cols = encodings.first().map_or(0, tokenizers::Encoding::len);
119
- let len = rows * cols;
128
+ if rows == 0 || cols == 0 {
129
+ return Ok(Tokenized { input_ids: Array2::zeros((0, 0)), attn_masks: Array2::zeros((0, 0)), type_ids: None });
130
+ }
120
131
 
121
- let mut input_ids = Vec::with_capacity(len);
122
- let mut attn_masks = Vec::with_capacity(len);
123
- let mut type_ids = with_type_ids.then(|| Vec::with_capacity(len));
132
+ let mut input_ids = Array2::zeros((0, cols));
133
+ let mut attn_masks = Array2::zeros((0, cols));
134
+ let mut type_ids = with_type_ids.then(|| Array2::zeros((0, cols)));
124
135
 
125
136
  for encoding in encodings {
126
- input_ids.extend(encoding.get_ids().iter().map(|&v| i64::from(v)));
127
- attn_masks.extend(encoding.get_attention_mask().iter().map(|&v| i64::from(v)));
128
-
129
- if let Some(type_ids) = type_ids.as_mut() {
130
- type_ids.extend(encoding.get_type_ids().iter().map(|&v| i64::from(v)));
137
+ input_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_ids())))?;
138
+ attn_masks.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_attention_mask())))?;
139
+ if let Some(ref mut type_ids) = type_ids {
140
+ type_ids.push_row(ndarray::ArrayView::from(&to_i64(encoding.get_type_ids())))?;
131
141
  }
132
142
  }
133
143
 
134
- Tokenized { rows, cols, input_ids, attn_masks, type_ids }
144
+ Ok(Tokenized { input_ids, attn_masks, type_ids })
135
145
  }
136
146
 
137
147
  #[cfg(test)]
@@ -154,9 +164,6 @@ mod tests {
154
164
 
155
165
  #[test]
156
166
  fn resolve_padding_strategy_auto_always_uses_batch_longest() {
157
- // Auto ignores fixed_padding_length from tokenizer.json — BatchLongest is
158
- // always faster for inference and correct for variable-length inputs.
159
- // Use PaddingMode::Fixed explicitly when fixed-length padding is required.
160
167
  assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 64, Some(64)), PaddingStrategy::BatchLongest));
161
168
  assert!(matches!(resolve_padding_strategy(PaddingMode::Auto, 512, None), PaddingStrategy::BatchLongest));
162
169
  }
@@ -1,4 +1,4 @@
1
- use gte::embedder::normalize_l2;
1
+ use gte::postprocess::normalize_l2;
2
2
  use ndarray::array;
3
3
 
4
4
  #[test]